1/* Halide.h -- interface for the 'Halide' library.
2
3 Copyright (c) 2012-2020 MIT CSAIL, Google, Facebook, Adobe, NVIDIA CORPORATION, and other contributors.
4
5 Developed by:
6
7 The Halide team
8 http://halide-lang.org
9
10 Permission is hereby granted, free of charge, to any person obtaining a copy of
11 this software and associated documentation files (the "Software"), to deal in
12 the Software without restriction, including without limitation the rights to
13 use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
14 of the Software, and to permit persons to whom the Software is furnished to do
15 so, subject to the following conditions:
16
17 The above copyright notice and this permission notice shall be included in all
18 copies or substantial portions of the Software.
19
20 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
21 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
23 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
25 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
26 SOFTWARE.
27
28 -----
29
30 apps/bgu is Copyright 2016 Google Inc. and is Licensed under the Apache License,
31 Version 2.0 (the "License"); you may not use this file except in compliance
32 with the License. You may obtain a copy of the License at
33
 http://www.apache.org/licenses/LICENSE-2.0
35
36 Unless required by applicable law or agreed to in writing, software
37 distributed under the License is distributed on an "AS IS" BASIS,
38 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
39 See the License for the specific language governing permissions and
40 limitations under the License.
41
42 -----
43
44 apps/support/cmdline.h is Copyright (c) 2009, Hideyuki Tanaka and is licensed
45 under the BSD 3-Clause license.
46
47 Redistribution and use in source and binary forms, with or without
48 modification, are permitted provided that the following conditions are met:
49 * Redistributions of source code must retain the above copyright
50 notice, this list of conditions and the following disclaimer.
51 * Redistributions in binary form must reproduce the above copyright
52 notice, this list of conditions and the following disclaimer in the
53 documentation and/or other materials provided with the distribution.
54 * Neither the name of the <organization> nor the
55 names of its contributors may be used to endorse or promote products
56 derived from this software without specific prior written permission.
57
58 THIS SOFTWARE IS PROVIDED BY <copyright holder> ''AS IS'' AND ANY
59 EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
60 WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
61 DISCLAIMED. IN NO EVENT SHALL <copyright holder> BE LIABLE FOR ANY
62 DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
63 (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
64 LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
65 ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
66 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
67 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
68
69*/
70
71#ifndef HALIDE_H
72#define HALIDE_H
73
74#ifndef HALIDE_ADD_ATOMIC_MUTEX_H
75#define HALIDE_ADD_ATOMIC_MUTEX_H
76
77#ifndef HALIDE_EXPR_H
78#define HALIDE_EXPR_H
79
80/** \file
81 * Base classes for Halide expressions (\ref Halide::Expr) and statements (\ref Halide::Internal::Stmt)
82 */
83
84#include <string>
85#include <vector>
86
87#ifndef HALIDE_INTRUSIVE_PTR_H
88#define HALIDE_INTRUSIVE_PTR_H
89
90/** \file
91 *
92 * Support classes for reference-counting via intrusive shared
93 * pointers.
94 */
95
96#include <atomic>
97#include <cstdlib>
98
99#ifndef HALIDE_HALIDERUNTIME_H
100#define HALIDE_HALIDERUNTIME_H
101
102#ifndef COMPILING_HALIDE_RUNTIME
103#ifdef __cplusplus
104#include <cstddef>
105#include <cstdint>
106#include <cstring>
107#else
108#include <stdbool.h>
109#include <stddef.h>
110#include <stdint.h>
111#include <string.h>
112#endif
113#else
114#error "COMPILING_HALIDE_RUNTIME should never be defined for Halide.h"
115#endif
116
117#ifdef __cplusplus
118// Forward declare type to allow naming typed handles.
119// See Type.h for documentation.
120template<typename T>
121struct halide_handle_traits;
122#endif
123
124#ifdef __cplusplus
125extern "C" {
126#endif
127
128#ifdef _MSC_VER
129// Note that (for MSVC) you should not use "inline" along with HALIDE_ALWAYS_INLINE;
130// it is not necessary, and may produce warnings for some build configurations.
131#define HALIDE_ALWAYS_INLINE __forceinline
132#define HALIDE_NEVER_INLINE __declspec(noinline)
133#else
134// Note that (for Posixy compilers) you should always use "inline" along with HALIDE_ALWAYS_INLINE;
135// otherwise some corner-case scenarios may erroneously report link errors.
136#define HALIDE_ALWAYS_INLINE inline __attribute__((always_inline))
137#define HALIDE_NEVER_INLINE __attribute__((noinline))
138#endif
139
140#ifndef HALIDE_MUST_USE_RESULT
141#ifdef __has_attribute
142#if __has_attribute(nodiscard)
143// C++17 or later
144#define HALIDE_MUST_USE_RESULT [[nodiscard]]
145#elif __has_attribute(warn_unused_result)
146// Clang/GCC
147#define HALIDE_MUST_USE_RESULT __attribute__((warn_unused_result))
148#else
149#define HALIDE_MUST_USE_RESULT
150#endif
151#else
152#define HALIDE_MUST_USE_RESULT
153#endif
154#endif
155
156/** \file
157 *
158 * This file declares the routines used by Halide internally in its
159 * runtime. On platforms that support weak linking, these can be
160 * replaced with user-defined versions by defining an extern "C"
161 * function with the same name and signature.
162 *
163 * When doing Just In Time (JIT) compilation methods on the Func being
164 * compiled must be called instead. The corresponding methods are
165 * documented below.
166 *
167 * All of these functions take a "void *user_context" parameter as their
168 * first argument; if the Halide kernel that calls back to any of these
169 * functions has been compiled with the UserContext feature set on its Target,
170 * then the value of that pointer passed from the code that calls the
171 * Halide kernel is piped through to the function.
172 *
173 * Some of these are also useful to call when using the default
174 * implementation. E.g. halide_shutdown_thread_pool.
175 *
176 * Note that even on platforms with weak linking, some linker setups
177 * may not respect the override you provide. E.g. if the override is
178 * in a shared library and the halide object files are linked directly
179 * into the output, the builtin versions of the runtime functions will
180 * be called. See your linker documentation for more details. On
181 * Linux, LD_DYNAMIC_WEAK=1 may help.
182 *
183 */
184
185// Forward-declare to suppress warnings if compiling as C.
186struct halide_buffer_t;
187
188/** Print a message to stderr. Main use is to support tracing
189 * functionality, print, and print_when calls. Also called by the default
190 * halide_error. This function can be replaced in JITed code by using
191 * halide_custom_print and providing an implementation of halide_print
192 * in AOT code. See Func::set_custom_print.
193 */
194// @{
195extern void halide_print(void *user_context, const char *);
196extern void halide_default_print(void *user_context, const char *);
197typedef void (*halide_print_t)(void *, const char *);
198extern halide_print_t halide_set_custom_print(halide_print_t print);
199// @}
200
201/** Halide calls this function on runtime errors (for example bounds
202 * checking failures). This function can be replaced in JITed code by
203 * using Func::set_error_handler, or in AOT code by calling
204 * halide_set_error_handler. In AOT code on platforms that support
205 * weak linking (i.e. not Windows), you can also override it by simply
206 * defining your own halide_error.
207 */
208// @{
209extern void halide_error(void *user_context, const char *);
210extern void halide_default_error(void *user_context, const char *);
211typedef void (*halide_error_handler_t)(void *, const char *);
212extern halide_error_handler_t halide_set_error_handler(halide_error_handler_t handler);
213// @}
214
215/** Cross-platform mutex. Must be initialized with zero and implementation
216 * must treat zero as an unlocked mutex with no waiters, etc.
217 */
struct halide_mutex {
    // Opaque platform-specific storage. Per the contract above, a value of
    // all zeroes must be treated as a valid, unlocked mutex with no waiters.
    uintptr_t _private[1];
};
221
222/** Cross platform condition variable. Must be initialized to 0. */
struct halide_cond {
    // Opaque platform-specific storage; must be zero-initialized.
    uintptr_t _private[1];
};
226
227/** A basic set of mutex and condition variable functions, which call
228 * platform specific code for mutual exclusion. Equivalent to posix
229 * calls. */
230//@{
231extern void halide_mutex_lock(struct halide_mutex *mutex);
232extern void halide_mutex_unlock(struct halide_mutex *mutex);
233extern void halide_cond_signal(struct halide_cond *cond);
234extern void halide_cond_broadcast(struct halide_cond *cond);
235extern void halide_cond_wait(struct halide_cond *cond, struct halide_mutex *mutex);
236//@}
237
238/** Functions for constructing/destroying/locking/unlocking arrays of mutexes. */
239struct halide_mutex_array;
240//@{
241extern struct halide_mutex_array *halide_mutex_array_create(int sz);
242extern void halide_mutex_array_destroy(void *user_context, void *array);
243extern int halide_mutex_array_lock(struct halide_mutex_array *array, int entry);
244extern int halide_mutex_array_unlock(struct halide_mutex_array *array, int entry);
245//@}
246
247/** Define halide_do_par_for to replace the default thread pool
248 * implementation. halide_shutdown_thread_pool can also be called to
249 * release resources used by the default thread pool on platforms
250 * where it makes sense. See Func::set_custom_do_task and
251 * Func::set_custom_do_par_for. Should return zero if all the jobs
252 * return zero, or an arbitrarily chosen return value from one of the
253 * jobs otherwise.
254 */
255//@{
256typedef int (*halide_task_t)(void *user_context, int task_number, uint8_t *closure);
257extern int halide_do_par_for(void *user_context,
258 halide_task_t task,
259 int min, int size, uint8_t *closure);
260extern void halide_shutdown_thread_pool();
261//@}
262
263/** Set a custom method for performing a parallel for loop. Returns
264 * the old do_par_for handler. */
265typedef int (*halide_do_par_for_t)(void *, halide_task_t, int, int, uint8_t *);
266extern halide_do_par_for_t halide_set_custom_do_par_for(halide_do_par_for_t do_par_for);
267
268/** An opaque struct representing a semaphore. Used by the task system for async tasks. */
struct halide_semaphore_t {
    // Opaque storage for the runtime's semaphore implementation.
    uint64_t _private[2];
};
272
273/** A struct representing a semaphore and a number of items that must
274 * be acquired from it. Used in halide_parallel_task_t below. */
struct halide_semaphore_acquire_t {
    struct halide_semaphore_t *semaphore;  // the semaphore to acquire from
    int count;                             // the number of items to acquire
};
/** Semaphore operations, plus matching function-pointer types so a custom
 * implementation can be installed (see halide_set_custom_parallel_runtime
 * below). halide_semaphore_try_acquire returns whether the acquisition
 * succeeded, without blocking. */
// @{
extern int halide_semaphore_init(struct halide_semaphore_t *, int n);
extern int halide_semaphore_release(struct halide_semaphore_t *, int n);
extern bool halide_semaphore_try_acquire(struct halide_semaphore_t *, int n);
typedef int (*halide_semaphore_init_t)(struct halide_semaphore_t *, int);
typedef int (*halide_semaphore_release_t)(struct halide_semaphore_t *, int);
typedef bool (*halide_semaphore_try_acquire_t)(struct halide_semaphore_t *, int);
// @}
285
286/** A task representing a serial for loop evaluated over some range.
287 * Note that task_parent is a pass through argument that should be
 * passed to any dependent tasks that are invoked using halide_do_parallel_tasks
289 * underneath this call. */
290typedef int (*halide_loop_task_t)(void *user_context, int min, int extent,
291 uint8_t *closure, void *task_parent);
292
293/** A parallel task to be passed to halide_do_parallel_tasks. This
294 * task may recursively call halide_do_parallel_tasks, and there may
295 * be complex dependencies between seemingly unrelated tasks expressed
296 * using semaphores. If you are using a custom task system, care must
297 * be taken to avoid potential deadlock. This can be done by carefully
298 * respecting the static metadata at the end of the task struct.*/
struct halide_parallel_task_t {
    // The function to call. It takes a user context, a min and
    // extent, a closure, and a task system pass through argument.
    halide_loop_task_t fn;

    // The closure to pass it
    uint8_t *closure;

    // The name of the function to be called. For debugging purposes only.
    const char *name;

    // An array of semaphores that must be acquired before the
    // function is called. Must be reacquired for every call made.
    struct halide_semaphore_acquire_t *semaphores;
    int num_semaphores;  // number of entries in the semaphores array

    // The entire range the function should be called over. This range
    // may be sliced up and the function called multiple times.
    int min, extent;

    // A parallel task provides several pieces of metadata to prevent
    // unbounded resource usage or deadlock.

    // The first is the minimum number of execution contexts (call
    // stacks or threads) necessary for the function to run to
    // completion. This may be greater than one when there is nested
    // parallelism with internal producer-consumer relationships
    // (calling the function recursively spawns and blocks on parallel
    // sub-tasks that communicate with each other via semaphores). If
    // a parallel runtime calls the function when fewer than this many
    // threads are idle, it may need to create more threads to
    // complete the task, or else risk deadlock due to committing all
    // threads to tasks that cannot complete without more.
    //
    // FIXME: Note that extern stages are assumed to only require a
    // single thread to complete. If the extern stage is itself a
    // Halide pipeline, this may be an underestimate.
    int min_threads;

    // If true, the calls to the function should be made in serial order
    // from min to min+extent-1, with only one executing at a time. If
    // false, any order is fine, and concurrency is fine.
    bool serial;
};
343
344/** Enqueue some number of the tasks described above and wait for them
 * to complete. While waiting, the calling thread assists with either
346 * the tasks enqueued, or other non-blocking tasks in the task
347 * system. Note that task_parent should be NULL for top-level calls
348 * and the pass through argument if this call is being made from
349 * another task. */
350extern int halide_do_parallel_tasks(void *user_context, int num_tasks,
351 struct halide_parallel_task_t *tasks,
352 void *task_parent);
353
354/** If you use the default do_par_for, you can still set a custom
355 * handler to perform each individual task. Returns the old handler. */
356//@{
357typedef int (*halide_do_task_t)(void *, halide_task_t, int, uint8_t *);
358extern halide_do_task_t halide_set_custom_do_task(halide_do_task_t do_task);
359extern int halide_do_task(void *user_context, halide_task_t f, int idx,
360 uint8_t *closure);
361//@}
362
363/** The version of do_task called for loop tasks. By default calls the
364 * loop task with the same arguments. */
365// @{
366typedef int (*halide_do_loop_task_t)(void *, halide_loop_task_t, int, int, uint8_t *, void *);
367extern halide_do_loop_task_t halide_set_custom_do_loop_task(halide_do_loop_task_t do_task);
368extern int halide_do_loop_task(void *user_context, halide_loop_task_t f, int min, int extent,
369 uint8_t *closure, void *task_parent);
370//@}
371
372/** Provide an entire custom tasking runtime via function
373 * pointers. Note that do_task and semaphore_try_acquire are only ever
374 * called by halide_default_do_par_for and
375 * halide_default_do_parallel_tasks, so it's only necessary to provide
376 * those if you are mixing in the default implementations of
377 * do_par_for and do_parallel_tasks. */
378// @{
379typedef int (*halide_do_parallel_tasks_t)(void *, int, struct halide_parallel_task_t *,
380 void *task_parent);
381extern void halide_set_custom_parallel_runtime(
382 halide_do_par_for_t,
383 halide_do_task_t,
384 halide_do_loop_task_t,
385 halide_do_parallel_tasks_t,
386 halide_semaphore_init_t,
387 halide_semaphore_try_acquire_t,
388 halide_semaphore_release_t);
389// @}
390
391/** The default versions of the parallel runtime functions. */
392// @{
393extern int halide_default_do_par_for(void *user_context,
394 halide_task_t task,
395 int min, int size, uint8_t *closure);
396extern int halide_default_do_parallel_tasks(void *user_context,
397 int num_tasks,
398 struct halide_parallel_task_t *tasks,
399 void *task_parent);
400extern int halide_default_do_task(void *user_context, halide_task_t f, int idx,
401 uint8_t *closure);
402extern int halide_default_do_loop_task(void *user_context, halide_loop_task_t f,
403 int min, int extent,
404 uint8_t *closure, void *task_parent);
405extern int halide_default_semaphore_init(struct halide_semaphore_t *, int n);
406extern int halide_default_semaphore_release(struct halide_semaphore_t *, int n);
407extern bool halide_default_semaphore_try_acquire(struct halide_semaphore_t *, int n);
408// @}
409
410struct halide_thread;
411
412/** Spawn a thread. Returns a handle to the thread for the purposes of
413 * joining it. The thread must be joined in order to clean up any
414 * resources associated with it. */
415extern struct halide_thread *halide_spawn_thread(void (*f)(void *), void *closure);
416
417/** Join a thread. */
418extern void halide_join_thread(struct halide_thread *);
419
420/** Set the number of threads used by Halide's thread pool. Returns
421 * the old number.
422 *
423 * n < 0 : error condition
424 * n == 0 : use a reasonable system default (typically, number of cpus online).
425 * n == 1 : use exactly one thread; this will always enforce serial execution
426 * n > 1 : use a pool of exactly n threads.
427 *
428 * (Note that this is only guaranteed when using the default implementations
429 * of halide_do_par_for(); custom implementations may completely ignore values
430 * passed to halide_set_num_threads().)
431 */
432extern int halide_set_num_threads(int n);
433
434/** Halide calls these functions to allocate and free memory. To
435 * replace in AOT code, use the halide_set_custom_malloc and
436 * halide_set_custom_free, or (on platforms that support weak
437 * linking), simply define these functions yourself. In JIT-compiled
438 * code use Func::set_custom_allocator.
439 *
440 * If you override them, and find yourself wanting to call the default
441 * implementation from within your override, use
442 * halide_default_malloc/free.
443 *
444 * Note that halide_malloc must return a pointer aligned to the
445 * maximum meaningful alignment for the platform for the purpose of
446 * vector loads and stores. The default implementation uses 32-byte
447 * alignment, which is safe for arm and x86. Additionally, it must be
448 * safe to read at least 8 bytes before the start and beyond the
449 * end.
450 */
451//@{
452extern void *halide_malloc(void *user_context, size_t x);
453extern void halide_free(void *user_context, void *ptr);
454extern void *halide_default_malloc(void *user_context, size_t x);
455extern void halide_default_free(void *user_context, void *ptr);
456typedef void *(*halide_malloc_t)(void *, size_t);
457typedef void (*halide_free_t)(void *, void *);
458extern halide_malloc_t halide_set_custom_malloc(halide_malloc_t user_malloc);
459extern halide_free_t halide_set_custom_free(halide_free_t user_free);
460//@}
461
462/** Halide calls these functions to interact with the underlying
463 * system runtime functions. To replace in AOT code on platforms that
464 * support weak linking, define these functions yourself, or use
465 * the halide_set_custom_load_library() and halide_set_custom_get_library_symbol()
466 * functions. In JIT-compiled code, use JITSharedRuntime::set_default_handlers().
467 *
468 * halide_load_library and halide_get_library_symbol are equivalent to
469 * dlopen and dlsym. halide_get_symbol(sym) is equivalent to
470 * dlsym(RTLD_DEFAULT, sym).
471 */
472//@{
473extern void *halide_get_symbol(const char *name);
474extern void *halide_load_library(const char *name);
475extern void *halide_get_library_symbol(void *lib, const char *name);
476extern void *halide_default_get_symbol(const char *name);
477extern void *halide_default_load_library(const char *name);
478extern void *halide_default_get_library_symbol(void *lib, const char *name);
479typedef void *(*halide_get_symbol_t)(const char *name);
480typedef void *(*halide_load_library_t)(const char *name);
481typedef void *(*halide_get_library_symbol_t)(void *lib, const char *name);
482extern halide_get_symbol_t halide_set_custom_get_symbol(halide_get_symbol_t user_get_symbol);
483extern halide_load_library_t halide_set_custom_load_library(halide_load_library_t user_load_library);
484extern halide_get_library_symbol_t halide_set_custom_get_library_symbol(halide_get_library_symbol_t user_get_library_symbol);
485//@}
486
487/** Called when debug_to_file is used inside %Halide code. See
488 * Func::debug_to_file for how this is called
489 *
490 * Cannot be replaced in JITted code at present.
491 */
492extern int32_t halide_debug_to_file(void *user_context, const char *filename,
493 int32_t type_code,
494 struct halide_buffer_t *buf);
495
496/** Types in the halide type system. They can be ints, unsigned ints,
497 * or floats (of various bit-widths), or a handle (which is always 64-bits).
498 * Note that the int/uint/float values do not imply a specific bit width
499 * (the bit width is expected to be encoded in a separate value).
500 */
typedef enum halide_type_code_t
#if (__cplusplus >= 201103L || _MSVC_LANG >= 201103L)
    : uint8_t  // with C++11 or later, pin the underlying type to one byte
#endif
{
    halide_type_int = 0,     ///< signed integers
    halide_type_uint = 1,    ///< unsigned integers
    halide_type_float = 2,   ///< IEEE floating point numbers
    halide_type_handle = 3,  ///< opaque pointer type (void *)
    halide_type_bfloat = 4,  ///< floating point numbers in the bfloat format
} halide_type_code_t;
512
513// Note that while __attribute__ can go before or after the declaration,
514// __declspec apparently is only allowed before.
515#ifndef HALIDE_ATTRIBUTE_ALIGN
516#ifdef _MSC_VER
517#define HALIDE_ATTRIBUTE_ALIGN(x) __declspec(align(x))
518#else
519#define HALIDE_ATTRIBUTE_ALIGN(x) __attribute__((aligned(x)))
520#endif
521#endif
522
523/** A runtime tag for a type in the halide type system. Can be ints,
524 * unsigned ints, or floats of various bit-widths (the 'bits'
525 * field). Can also be vectors of the same (by setting the 'lanes'
526 * field to something larger than one). This struct should be
527 * exactly 32-bits in size. */
528struct halide_type_t {
529 /** The basic type code: signed integer, unsigned integer, or floating point. */
530#if (__cplusplus >= 201103L || _MSVC_LANG >= 201103L)
531 HALIDE_ATTRIBUTE_ALIGN(1)
532 halide_type_code_t code; // halide_type_code_t
533#else
534 HALIDE_ATTRIBUTE_ALIGN(1)
535 uint8_t code; // halide_type_code_t
536#endif
537
538 /** The number of bits of precision of a single scalar value of this type. */
539 HALIDE_ATTRIBUTE_ALIGN(1)
540 uint8_t bits;
541
542 /** How many elements in a vector. This is 1 for scalar types. */
543 HALIDE_ATTRIBUTE_ALIGN(2)
544 uint16_t lanes;
545
546#if (__cplusplus >= 201103L || _MSVC_LANG >= 201103L)
547 /** Construct a runtime representation of a Halide type from:
548 * code: The fundamental type from an enum.
549 * bits: The bit size of one element.
550 * lanes: The number of vector elements in the type. */
551 HALIDE_ALWAYS_INLINE halide_type_t(halide_type_code_t code, uint8_t bits, uint16_t lanes = 1)
552 : code(code), bits(bits), lanes(lanes) {
553 }
554
555 /** Default constructor is required e.g. to declare halide_trace_event
556 * instances. */
557 HALIDE_ALWAYS_INLINE halide_type_t()
558 : code((halide_type_code_t)0), bits(0), lanes(0) {
559 }
560
561 HALIDE_ALWAYS_INLINE halide_type_t with_lanes(uint16_t new_lanes) const {
562 return halide_type_t((halide_type_code_t)code, bits, new_lanes);
563 }
564
565 /** Compare two types for equality. */
566 HALIDE_ALWAYS_INLINE bool operator==(const halide_type_t &other) const {
567 return as_u32() == other.as_u32();
568 }
569
570 HALIDE_ALWAYS_INLINE bool operator!=(const halide_type_t &other) const {
571 return !(*this == other);
572 }
573
574 HALIDE_ALWAYS_INLINE bool operator<(const halide_type_t &other) const {
575 return as_u32() < other.as_u32();
576 }
577
578 /** Size in bytes for a single element, even if width is not 1, of this type. */
579 HALIDE_ALWAYS_INLINE int bytes() const {
580 return (bits + 7) / 8;
581 }
582
583 HALIDE_ALWAYS_INLINE uint32_t as_u32() const {
584 uint32_t u;
585 memcpy(&u, this, sizeof(u));
586 return u;
587 }
588#endif
589};
590
/** The kinds of trace events; see halide_trace_event_t and
 * halide_trace_packet_t below, which carry one of these codes. */
enum halide_trace_event_code_t { halide_trace_load = 0,
                                 halide_trace_store = 1,
                                 halide_trace_begin_realization = 2,
                                 halide_trace_end_realization = 3,
                                 halide_trace_produce = 4,
                                 halide_trace_end_produce = 5,
                                 halide_trace_consume = 6,
                                 halide_trace_end_consume = 7,
                                 halide_trace_begin_pipeline = 8,
                                 halide_trace_end_pipeline = 9,
                                 halide_trace_tag = 10 };
602
struct halide_trace_event_t {
    /** The name of the Func or Pipeline that this event refers to */
    const char *func;

    /** If the event type is a load or a store, this points to the
     * value being loaded or stored. Use the type field to safely cast
     * this to a concrete pointer type and retrieve it. For other
     * events this is null. */
    void *value;

    /** For loads and stores, an array which contains the location
     * being accessed. For vector loads or stores it is an array of
     * vectors of coordinates (the vector dimension is innermost).
     *
     * For realization or production-related events, this will contain
     * the mins and extents of the region being accessed, in the order
     * min0, extent0, min1, extent1, ...
     *
     * For pipeline-related events, this will be null.
     */
    int32_t *coordinates;

    /** For halide_trace_tag, this points to a read-only null-terminated string
     * of arbitrary text. For all other events, this will be null.
     */
    const char *trace_tag;

    /** If the event type is a load or a store, this is the type of
     * the data. Otherwise, the value is meaningless. */
    struct halide_type_t type;

    /** The type of event */
    enum halide_trace_event_code_t event;

    /** The ID of the parent event (see below for an explanation of
     * event ancestry). */
    int32_t parent_id;

    /** If this was a load or store of a Tuple-valued Func, this is
     * which tuple element was accessed. */
    int32_t value_index;

    /** The length of the coordinates array */
    int32_t dimensions;

#if (__cplusplus >= 201103L || _MSVC_LANG >= 201103L)
    // If we don't explicitly mark the default ctor as inline,
    // certain build configurations can fail (notably iOS)
    HALIDE_ALWAYS_INLINE halide_trace_event_t() = default;
#endif
};
654
655/** Called when Funcs are marked as trace_load, trace_store, or
656 * trace_realization. See Func::set_custom_trace. The default
657 * implementation either prints events via halide_print, or if
658 * HL_TRACE_FILE is defined, dumps the trace to that file in a
659 * sequence of trace packets. The header for a trace packet is defined
660 * below. If the trace is going to be large, you may want to make the
661 * file a named pipe, and then read from that pipe into gzip.
662 *
663 * halide_trace returns a unique ID which will be passed to future
664 * events that "belong" to the earlier event as the parent id. The
665 * ownership hierarchy looks like:
666 *
667 * begin_pipeline
668 * +--trace_tag (if any)
669 * +--trace_tag (if any)
670 * ...
671 * +--begin_realization
672 * | +--produce
673 * | | +--load/store
674 * | | +--end_produce
675 * | +--consume
676 * | | +--load
677 * | | +--end_consume
678 * | +--end_realization
679 * +--end_pipeline
680 *
681 * Threading means that ownership cannot be inferred from the ordering
682 * of events. There can be many active realizations of a given
683 * function, or many active productions for a single
684 * realization. Within a single production, the ordering of events is
685 * meaningful.
686 *
687 * Note that all trace_tag events (if any) will occur just after the begin_pipeline
688 * event, but before any begin_realization events. All trace_tags for a given Func
689 * will be emitted in the order added.
690 */
// @{
692extern int32_t halide_trace(void *user_context, const struct halide_trace_event_t *event);
693extern int32_t halide_default_trace(void *user_context, const struct halide_trace_event_t *event);
694typedef int32_t (*halide_trace_t)(void *user_context, const struct halide_trace_event_t *);
695extern halide_trace_t halide_set_custom_trace(halide_trace_t trace);
696// @}
697
698/** The header of a packet in a binary trace. All fields are 32-bit. */
struct halide_trace_packet_t {
    /** The total size of this packet in bytes. Always a multiple of
     * four. Equivalently, the number of bytes until the next
     * packet. */
    uint32_t size;

    /** The id of this packet (for the purpose of parent_id). */
    int32_t id;

    /** The remaining fields are equivalent to those in halide_trace_event_t */
    // @{
    struct halide_type_t type;
    enum halide_trace_event_code_t event;
    int32_t parent_id;
    int32_t value_index;
    int32_t dimensions;
    // @}

#if (__cplusplus >= 201103L || _MSVC_LANG >= 201103L)
    // If we don't explicitly mark the default ctor as inline,
    // certain build configurations can fail (notably iOS)
    HALIDE_ALWAYS_INLINE halide_trace_packet_t() = default;

    /** Get the coordinates array, assuming this packet is laid out in
     * memory as it was written. The coordinates array comes
     * immediately after the packet header. */
    HALIDE_ALWAYS_INLINE const int *coordinates() const {
        // The variable-length payload starts directly after this
        // fixed-size header struct.
        return (const int *)(this + 1);
    }

    HALIDE_ALWAYS_INLINE int *coordinates() {
        return (int *)(this + 1);
    }

    /** Get the value, assuming this packet is laid out in memory as
     * it was written. The packet comes immediately after the coordinates
     * array. */
    HALIDE_ALWAYS_INLINE const void *value() const {
        return (const void *)(coordinates() + dimensions);
    }

    HALIDE_ALWAYS_INLINE void *value() {
        return (void *)(coordinates() + dimensions);
    }

    /** Get the func name, assuming this packet is laid out in memory
     * as it was written. It comes after the value. */
    HALIDE_ALWAYS_INLINE const char *func() const {
        // The value occupies lanes * bytes()-per-element bytes.
        return (const char *)value() + type.lanes * type.bytes();
    }

    HALIDE_ALWAYS_INLINE char *func() {
        return (char *)value() + type.lanes * type.bytes();
    }

    /** Get the trace_tag (if any), assuming this packet is laid out in memory
     * as it was written. It comes after the func name. If there is no trace_tag,
     * this will return a pointer to an empty string. */
    HALIDE_ALWAYS_INLINE const char *trace_tag() const {
        const char *f = func();
        // strlen may not be available here
        while (*f++) {
            // nothing
        }
        // f now points just past the func name's null terminator,
        // which is where the (possibly empty) trace_tag string begins.
        return f;
    }

    HALIDE_ALWAYS_INLINE char *trace_tag() {
        char *f = func();
        // strlen may not be available here
        while (*f++) {
            // nothing
        }
        return f;
    }
#endif
};
776
777/** Set the file descriptor that Halide should write binary trace
778 * events to. If called with 0 as the argument, Halide outputs trace
779 * information to stdout in a human-readable format. If never called,
 * Halide checks for the existence of an environment variable called
781 * HL_TRACE_FILE and opens that file. If HL_TRACE_FILE is not defined,
782 * it outputs trace information to stdout in a human-readable
783 * format. */
784extern void halide_set_trace_file(int fd);
785
786/** Halide calls this to retrieve the file descriptor to write binary
787 * trace events to. The default implementation returns the value set
788 * by halide_set_trace_file. Implement it yourself if you wish to use
789 * a custom file descriptor per user_context. Return zero from your
790 * implementation to tell Halide to print human-readable trace
791 * information to stdout. */
792extern int halide_get_trace_file(void *user_context);
793
/** If tracing is writing to a file, this call closes that file
 * (flushing the trace). Returns zero on success. */
796extern int halide_shutdown_trace();
797
798/** All Halide GPU or device backend implementations provide an
799 * interface to be used with halide_device_malloc, etc. This is
800 * accessed via the functions below.
801 */
802
803/** An opaque struct containing per-GPU API implementations of the
804 * device functions. */
805struct halide_device_interface_impl_t;
806
807/** Each GPU API provides a halide_device_interface_t struct pointing
808 * to the code that manages device allocations. You can access these
809 * functions directly from the struct member function pointers, or by
810 * calling the functions declared below. Note that the global
811 * functions are not available when using Halide as a JIT compiler.
812 * If you are using raw halide_buffer_t in that context you must use
813 * the function pointers in the device_interface struct.
814 *
815 * The function pointers below are currently the same for every GPU
816 * API; only the impl field varies. These top-level functions do the
817 * bookkeeping that is common across all GPU APIs, and then dispatch
818 * to more API-specific functions via another set of function pointers
819 * hidden inside the impl field.
820 */
struct halide_device_interface_t {
    /** Allocate device-side storage for buf (see halide_device_malloc). */
    int (*device_malloc)(void *user_context, struct halide_buffer_t *buf,
                         const struct halide_device_interface_t *device_interface);
    /** Free buf's device allocation (see halide_device_free). */
    int (*device_free)(void *user_context, struct halide_buffer_t *buf);
    /** Wait for pending device operations on buf (see halide_device_sync). */
    int (*device_sync)(void *user_context, struct halide_buffer_t *buf);
    /** Release all resources held by this device interface
     * (see halide_device_release). */
    void (*device_release)(void *user_context,
                           const struct halide_device_interface_t *device_interface);
    /** Copy buf's data from device memory to host memory
     * (see halide_copy_to_host). */
    int (*copy_to_host)(void *user_context, struct halide_buffer_t *buf);
    /** Copy buf's data from host memory to device memory
     * (see halide_copy_to_device). */
    int (*copy_to_device)(void *user_context, struct halide_buffer_t *buf,
                          const struct halide_device_interface_t *device_interface);
    /** Allocate host and device storage for buf together. */
    int (*device_and_host_malloc)(void *user_context, struct halide_buffer_t *buf,
                                  const struct halide_device_interface_t *device_interface);
    /** Free storage allocated by device_and_host_malloc. */
    int (*device_and_host_free)(void *user_context, struct halide_buffer_t *buf);
    /** Copy data between buffers, possibly across memory spaces
     * (see halide_buffer_copy). */
    int (*buffer_copy)(void *user_context, struct halide_buffer_t *src,
                       const struct halide_device_interface_t *dst_device_interface, struct halide_buffer_t *dst);
    /** Give dst a device allocation aliasing a cropped region of src
     * (see halide_device_crop). */
    int (*device_crop)(void *user_context, const struct halide_buffer_t *src,
                       struct halide_buffer_t *dst);
    /** Give dst a device allocation aliasing src with one dimension
     * sliced away (see halide_device_slice). */
    int (*device_slice)(void *user_context, const struct halide_buffer_t *src,
                        int slice_dim, int slice_pos, struct halide_buffer_t *dst);
    /** Release resources of a cropped/sliced view
     * (see halide_device_release_crop). */
    int (*device_release_crop)(void *user_context, struct halide_buffer_t *buf);
    /** Wrap a native device handle in buf (see halide_device_wrap_native). */
    int (*wrap_native)(void *user_context, struct halide_buffer_t *buf, uint64_t handle,
                       const struct halide_device_interface_t *device_interface);
    /** Detach a previously wrapped native handle
     * (see halide_device_detach_native). */
    int (*detach_native)(void *user_context, struct halide_buffer_t *buf);
    /** Query the device's compute capability (major/minor), where
     * meaningful for the API. */
    int (*compute_capability)(void *user_context, int *major, int *minor);
    /** API-specific implementations backing the functions above. */
    const struct halide_device_interface_impl_t *impl;
};
847
848/** Release all data associated with the given device interface, in
849 * particular all resources (memory, texture, context handles)
850 * allocated by Halide. Must be called explicitly when using AOT
851 * compilation. This is *not* thread-safe with respect to actively
852 * running Halide code. Ensure all pipelines are finished before
853 * calling this. */
854extern void halide_device_release(void *user_context,
855 const struct halide_device_interface_t *device_interface);
856
857/** Copy image data from device memory to host memory. This must be called
858 * explicitly to copy back the results of a GPU-based filter. */
859extern int halide_copy_to_host(void *user_context, struct halide_buffer_t *buf);
860
861/** Copy image data from host memory to device memory. This should not
862 * be called directly; Halide handles copying to the device
863 * automatically. If interface is NULL and the buf has a non-zero dev
864 * field, the device associated with the dev handle will be
865 * used. Otherwise if the dev field is 0 and interface is NULL, an
866 * error is returned. */
867extern int halide_copy_to_device(void *user_context, struct halide_buffer_t *buf,
868 const struct halide_device_interface_t *device_interface);
869
870/** Copy data from one buffer to another. The buffers may have
871 * different shapes and sizes, but the destination buffer's shape must
872 * be contained within the source buffer's shape. That is, for each
873 * dimension, the min on the destination buffer must be greater than
874 * or equal to the min on the source buffer, and min+extent on the
 * destination buffer must be less than or equal to min+extent on the
876 * source buffer. The source data is pulled from either device or
877 * host memory on the source, depending on the dirty flags. host is
878 * preferred if both are valid. The dst_device_interface parameter
879 * controls the destination memory space. NULL means host memory. */
880extern int halide_buffer_copy(void *user_context, struct halide_buffer_t *src,
881 const struct halide_device_interface_t *dst_device_interface,
882 struct halide_buffer_t *dst);
883
884/** Give the destination buffer a device allocation which is an alias
885 * for the same coordinate range in the source buffer. Modifies the
886 * device, device_interface, and the device_dirty flag only. Only
887 * supported by some device APIs (others will return
888 * halide_error_code_device_crop_unsupported). Call
889 * halide_device_release_crop instead of halide_device_free to clean
890 * up resources associated with the cropped view. Do not free the
891 * device allocation on the source buffer while the destination buffer
892 * still lives. Note that the two buffers do not share dirty flags, so
893 * care must be taken to update them together as needed. Note that src
894 * and dst are required to have the same number of dimensions.
895 *
896 * Note also that (in theory) device interfaces which support cropping may
897 * still not support cropping a crop (instead, create a new crop of the parent
898 * buffer); in practice, no known implementation has this limitation, although
899 * it is possible that some future implementations may require it. */
900extern int halide_device_crop(void *user_context,
901 const struct halide_buffer_t *src,
902 struct halide_buffer_t *dst);
903
904/** Give the destination buffer a device allocation which is an alias
905 * for a similar coordinate range in the source buffer, but with one dimension
906 * sliced away in the dst. Modifies the device, device_interface, and the
907 * device_dirty flag only. Only supported by some device APIs (others will return
908 * halide_error_code_device_crop_unsupported). Call
909 * halide_device_release_crop instead of halide_device_free to clean
910 * up resources associated with the sliced view. Do not free the
911 * device allocation on the source buffer while the destination buffer
912 * still lives. Note that the two buffers do not share dirty flags, so
913 * care must be taken to update them together as needed. Note that the dst buffer
914 * must have exactly one fewer dimension than the src buffer, and that slice_dim
915 * and slice_pos must be valid within src. */
916extern int halide_device_slice(void *user_context,
917 const struct halide_buffer_t *src,
918 int slice_dim, int slice_pos,
919 struct halide_buffer_t *dst);
920
921/** Release any resources associated with a cropped/sliced view of another
922 * buffer. */
923extern int halide_device_release_crop(void *user_context,
924 struct halide_buffer_t *buf);
925
926/** Wait for current GPU operations to complete. Calling this explicitly
927 * should rarely be necessary, except maybe for profiling. */
928extern int halide_device_sync(void *user_context, struct halide_buffer_t *buf);
929
930/** Allocate device memory to back a halide_buffer_t. */
931extern int halide_device_malloc(void *user_context, struct halide_buffer_t *buf,
932 const struct halide_device_interface_t *device_interface);
933
934/** Free device memory. */
935extern int halide_device_free(void *user_context, struct halide_buffer_t *buf);
936
937/** Wrap or detach a native device handle, setting the device field
938 * and device_interface field as appropriate for the given GPU
939 * API. The meaning of the opaque handle is specific to the device
940 * interface, so if you know the device interface in use, call the
941 * more specific functions in the runtime headers for your specific
942 * device API instead (e.g. HalideRuntimeCuda.h). */
943// @{
944extern int halide_device_wrap_native(void *user_context,
945 struct halide_buffer_t *buf,
946 uint64_t handle,
947 const struct halide_device_interface_t *device_interface);
948extern int halide_device_detach_native(void *user_context, struct halide_buffer_t *buf);
949// @}
950
951/** Selects which gpu device to use. 0 is usually the display
952 * device. If never called, Halide uses the environment variable
953 * HL_GPU_DEVICE. If that variable is unset, Halide uses the last
954 * device. Set this to -1 to use the last device. */
955extern void halide_set_gpu_device(int n);
956
957/** Halide calls this to get the desired halide gpu device
958 * setting. Implement this yourself to use a different gpu device per
959 * user_context. The default implementation returns the value set by
960 * halide_set_gpu_device, or the environment variable
961 * HL_GPU_DEVICE. */
962extern int halide_get_gpu_device(void *user_context);
963
964/** Set the soft maximum amount of memory, in bytes, that the LRU
965 * cache will use to memoize Func results. This is not a strict
966 * maximum in that concurrency and simultaneous use of memoized
 * results larger than the cache size can both cause it to
 * temporarily be larger than the size specified here.
969 */
970extern void halide_memoization_cache_set_size(int64_t size);
971
972/** Given a cache key for a memoized result, currently constructed
973 * from the Func name and top-level Func name plus the arguments of
974 * the computation, determine if the result is in the cache and
975 * return it if so. (The internals of the cache key should be
 * considered opaque by this function.) If this routine returns 1,
 * it is a cache miss. Otherwise, it will return 0 and the
978 * buffers passed in will be filled, via copying, with memoized
 * data. The last argument is a list of halide_buffer_t pointers which
980 * represents the outputs of the memoized Func. If the Func does not
981 * return a Tuple, there will only be one halide_buffer_t in the list. The
 * tuple_count parameter determines the length of the list.
983 *
984 * The return values are:
985 * -1: Signals an error.
986 * 0: Success and cache hit.
987 * 1: Success and cache miss.
988 */
989extern int halide_memoization_cache_lookup(void *user_context, const uint8_t *cache_key, int32_t size,
990 struct halide_buffer_t *realized_bounds,
991 int32_t tuple_count, struct halide_buffer_t **tuple_buffers);
992
993/** Given a cache key for a memoized result, currently constructed
994 * from the Func name and top-level Func name plus the arguments of
 * the computation, store the result in the cache for future access by
996 * halide_memoization_cache_lookup. (The internals of the cache key
997 * should be considered opaque by this function.) Data is copied out
998 * from the inputs and inputs are unmodified. The last argument is a
 * list of halide_buffer_t pointers which represents the outputs of the
1000 * memoized Func. If the Func does not return a Tuple, there will
 * only be one halide_buffer_t in the list. The tuple_count parameter
1002 * determines the length of the list.
1003 *
1004 * If there is a memory allocation failure, the store does not store
1005 * the data into the cache.
1006 *
1007 * If has_eviction_key is true, the entry is marked with eviction_key to
1008 * allow removing the key with halide_memoization_cache_evict.
1009 */
1010extern int halide_memoization_cache_store(void *user_context, const uint8_t *cache_key, int32_t size,
1011 struct halide_buffer_t *realized_bounds,
1012 int32_t tuple_count,
1013 struct halide_buffer_t **tuple_buffers,
1014 bool has_eviction_key, uint64_t eviction_key);
1015
1016/** Evict all cache entries that were tagged with the given
1017 * eviction_key in the memoize scheduling directive.
1018 */
1019extern void halide_memoization_cache_evict(void *user_context, uint64_t eviction_key);
1020
1021/** If halide_memoization_cache_lookup succeeds,
1022 * halide_memoization_cache_release must be called to signal the
1023 * storage is no longer being used by the caller. It will be passed
 * the host pointer of one of the buffers returned by
1025 * halide_memoization_cache_lookup. That is
1026 * halide_memoization_cache_release will be called multiple times for
1027 * the case where halide_memoization_cache_lookup is handling multiple
1028 * buffers. (This corresponds to memoizing a Tuple in Halide.) Note
1029 * that the host pointer must be sufficient to get to all information
 * the release operation needs. The default Halide cache implementation
1031 * accomplishes this by storing extra data before the start of the user
1032 * modifiable host storage.
1033 *
1034 * This call is like free and does not have a failure return.
1035 */
1036extern void halide_memoization_cache_release(void *user_context, void *host);
1037
1038/** Free all memory and resources associated with the memoization cache.
1039 * Must be called at a time when no other threads are accessing the cache.
1040 */
1041extern void halide_memoization_cache_cleanup();
1042
1043/** Verify that a given range of memory has been initialized; only used when Target::MSAN is enabled.
1044 *
1045 * The default implementation simply calls the LLVM-provided __msan_check_mem_is_initialized() function.
1046 *
1047 * The return value should always be zero.
1048 */
1049extern int halide_msan_check_memory_is_initialized(void *user_context, const void *ptr, uint64_t len, const char *name);
1050
1051/** Verify that the data pointed to by the halide_buffer_t is initialized (but *not* the halide_buffer_t itself),
1052 * using halide_msan_check_memory_is_initialized() for checking.
1053 *
1054 * The default implementation takes pains to only check the active memory ranges
1055 * (skipping padding), and sorting into ranges to always check the smallest number of
1056 * ranges, in monotonically increasing memory order.
1057 *
1058 * Most client code should never need to replace the default implementation.
1059 *
1060 * The return value should always be zero.
1061 */
1062extern int halide_msan_check_buffer_is_initialized(void *user_context, struct halide_buffer_t *buffer, const char *buf_name);
1063
1064/** Annotate that a given range of memory has been initialized;
1065 * only used when Target::MSAN is enabled.
1066 *
1067 * The default implementation simply calls the LLVM-provided __msan_unpoison() function.
1068 *
1069 * The return value should always be zero.
1070 */
1071extern int halide_msan_annotate_memory_is_initialized(void *user_context, const void *ptr, uint64_t len);
1072
1073/** Mark the data pointed to by the halide_buffer_t as initialized (but *not* the halide_buffer_t itself),
1074 * using halide_msan_annotate_memory_is_initialized() for marking.
1075 *
1076 * The default implementation takes pains to only mark the active memory ranges
1077 * (skipping padding), and sorting into ranges to always mark the smallest number of
1078 * ranges, in monotonically increasing memory order.
1079 *
1080 * Most client code should never need to replace the default implementation.
1081 *
1082 * The return value should always be zero.
1083 */
1084extern int halide_msan_annotate_buffer_is_initialized(void *user_context, struct halide_buffer_t *buffer);
1085extern void halide_msan_annotate_buffer_is_initialized_as_destructor(void *user_context, void *buffer);
1086
1087/** The error codes that may be returned by a Halide pipeline. */
enum halide_error_code_t {
    /** There was no error. This is the value returned by Halide on success. */
    halide_error_code_success = 0,

    /** An uncategorized error occurred. Refer to the string passed to halide_error. */
    halide_error_code_generic_error = -1,

    /** A Func was given an explicit bound via Func::bound, but this
     * was not large enough to encompass the region that is used of
     * the Func by the rest of the pipeline. */
    halide_error_code_explicit_bounds_too_small = -2,

    /** The elem_size field of a halide_buffer_t does not match the size in
     * bytes of the type of that ImageParam. Probable type mismatch. */
    halide_error_code_bad_type = -3,

    /** A pipeline would access memory outside of the halide_buffer_t passed
     * in. */
    halide_error_code_access_out_of_bounds = -4,

    /** A halide_buffer_t was given that spans more than 2GB of memory. */
    halide_error_code_buffer_allocation_too_large = -5,

    /** A halide_buffer_t was given with extents that multiply to a number
     * greater than 2^31-1 */
    halide_error_code_buffer_extents_too_large = -6,

    /** Applying explicit constraints on the size of an input or
     * output buffer shrank the size of that buffer below what will be
     * accessed by the pipeline. */
    halide_error_code_constraints_make_required_region_smaller = -7,

    /** A constraint on a size or stride of an input or output buffer
     * was not met by the halide_buffer_t passed in. */
    halide_error_code_constraint_violated = -8,

    /** A scalar parameter passed in was smaller than its minimum
     * declared value. */
    halide_error_code_param_too_small = -9,

    /** A scalar parameter passed in was greater than its maximum
     * declared value. */
    halide_error_code_param_too_large = -10,

    /** A call to halide_malloc returned NULL. */
    halide_error_code_out_of_memory = -11,

    /** A halide_buffer_t pointer passed in was NULL. */
    halide_error_code_buffer_argument_is_null = -12,

    /** debug_to_file failed to open or write to the specified
     * file. */
    halide_error_code_debug_to_file_failed = -13,

    /** The Halide runtime encountered an error while trying to copy
     * from device to host. Turn on -debug in your target string to
     * see more details. */
    halide_error_code_copy_to_host_failed = -14,

    /** The Halide runtime encountered an error while trying to copy
     * from host to device. Turn on -debug in your target string to
     * see more details. */
    halide_error_code_copy_to_device_failed = -15,

    /** The Halide runtime encountered an error while trying to
     * allocate memory on device. Turn on -debug in your target string
     * to see more details. */
    halide_error_code_device_malloc_failed = -16,

    /** The Halide runtime encountered an error while trying to
     * synchronize with a device. Turn on -debug in your target string
     * to see more details. */
    halide_error_code_device_sync_failed = -17,

    /** The Halide runtime encountered an error while trying to free a
     * device allocation. Turn on -debug in your target string to see
     * more details. */
    halide_error_code_device_free_failed = -18,

    /** Buffer has a non-zero device but no device interface, which
     * violates a Halide invariant. */
    halide_error_code_no_device_interface = -19,

    /** An error occurred when attempting to initialize the Matlab
     * runtime. */
    halide_error_code_matlab_init_failed = -20,

    /** The type of an mxArray did not match the expected type. */
    halide_error_code_matlab_bad_param_type = -21,

    /** There is a bug in the Halide compiler. */
    halide_error_code_internal_error = -22,

    /** The Halide runtime encountered an error while trying to launch
     * a GPU kernel. Turn on -debug in your target string to see more
     * details. */
    halide_error_code_device_run_failed = -23,

    /** The Halide runtime encountered a host pointer that violated
     * the alignment set for it by way of a call to
     * set_host_alignment */
    halide_error_code_unaligned_host_ptr = -24,

    /** A fold_storage directive was used on a dimension that is not
     * accessed in a monotonically increasing or decreasing fashion. */
    halide_error_code_bad_fold = -25,

    /** A fold_storage directive was used with a fold factor that was
     * too small to store all the values of a producer needed by the
     * consumer. */
    halide_error_code_fold_factor_too_small = -26,

    /** User-specified require() expression was not satisfied. */
    halide_error_code_requirement_failed = -27,

    /** At least one of the buffer's extents are negative. */
    halide_error_code_buffer_extents_negative = -28,

    halide_error_code_unused_29 = -29,

    halide_error_code_unused_30 = -30,

    /** A specialize_fail() schedule branch was selected at runtime. */
    halide_error_code_specialize_fail = -31,

    /** The Halide runtime encountered an error while trying to wrap a
     * native device handle. Turn on -debug in your target string to
     * see more details. */
    halide_error_code_device_wrap_native_failed = -32,

    /** The Halide runtime encountered an error while trying to detach
     * a native device handle. Turn on -debug in your target string
     * to see more details. */
    halide_error_code_device_detach_native_failed = -33,

    /** The host field on an input or output was null, the device
     * field was not zero, and the pipeline tries to use the buffer on
     * the host. You may be passing a GPU-only buffer to a pipeline
     * which is scheduled to use it on the CPU. */
    halide_error_code_host_is_null = -34,

    /** A folded buffer was passed to an extern stage, but the region
     * touched wraps around the fold boundary. */
    halide_error_code_bad_extern_fold = -35,

    /** Buffer has a non-null device_interface but device is 0, which
     * violates a Halide invariant. */
    halide_error_code_device_interface_no_device = -36,

    /** Buffer has both host and device dirty bits set, which violates
     * a Halide invariant. */
    halide_error_code_host_and_device_dirty = -37,

    /** The halide_buffer_t * passed to a halide runtime routine is
     * nullptr and this is not allowed. */
    halide_error_code_buffer_is_null = -38,

    /** The Halide runtime encountered an error while trying to copy
     * from one buffer to another. Turn on -debug in your target
     * string to see more details. */
    halide_error_code_device_buffer_copy_failed = -39,

    /** Attempted to make cropped/sliced alias of a buffer with a device
     * field, but the device_interface does not support cropping. */
    halide_error_code_device_crop_unsupported = -40,

    /** Cropping/slicing a buffer failed for some other reason. Turn on -debug
     * in your target string. */
    halide_error_code_device_crop_failed = -41,

    /** An operation on a buffer required an allocation on a
     * particular device interface, but a device allocation already
     * existed on a different device interface. Free the old one
     * first. */
    halide_error_code_incompatible_device_interface = -42,

    /** The dimensions field of a halide_buffer_t does not match the dimensions of that ImageParam. */
    halide_error_code_bad_dimensions = -43,

    /** A buffer with the device_dirty flag set was passed to a
     * pipeline compiled with no device backends enabled, so it
     * doesn't know how to copy the data back from device memory to
     * host memory. Either call copy_to_host before calling the Halide
     * pipeline, or enable the appropriate device backend. */
    halide_error_code_device_dirty_with_no_device_support = -44,

};
1275
1276/** Halide calls the functions below on various error conditions. The
1277 * default implementations construct an error message, call
1278 * halide_error, then return the matching error code above. On
1279 * platforms that support weak linking, you can override these to
1280 * catch the errors individually. */
1281
1282/** A call into an extern stage for the purposes of bounds inference
1283 * failed. Returns the error code given by the extern stage. */
1284extern int halide_error_bounds_inference_call_failed(void *user_context, const char *extern_stage_name, int result);
1285
/** A call to an extern stage failed. Returns the error code given by
1287 * the extern stage. */
1288extern int halide_error_extern_stage_failed(void *user_context, const char *extern_stage_name, int result);
1289
1290/** Various other error conditions. See the enum above for a
1291 * description of each. */
1292// @{
1293extern int halide_error_explicit_bounds_too_small(void *user_context, const char *func_name, const char *var_name,
1294 int min_bound, int max_bound, int min_required, int max_required);
1295extern int halide_error_bad_type(void *user_context, const char *func_name,
1296 uint32_t type_given, uint32_t correct_type); // N.B. The last two args are the bit representation of a halide_type_t
1297extern int halide_error_bad_dimensions(void *user_context, const char *func_name,
1298 int32_t dimensions_given, int32_t correct_dimensions);
1299extern int halide_error_access_out_of_bounds(void *user_context, const char *func_name,
1300 int dimension, int min_touched, int max_touched,
1301 int min_valid, int max_valid);
1302extern int halide_error_buffer_allocation_too_large(void *user_context, const char *buffer_name,
1303 uint64_t allocation_size, uint64_t max_size);
1304extern int halide_error_buffer_extents_negative(void *user_context, const char *buffer_name, int dimension, int extent);
1305extern int halide_error_buffer_extents_too_large(void *user_context, const char *buffer_name,
1306 int64_t actual_size, int64_t max_size);
1307extern int halide_error_constraints_make_required_region_smaller(void *user_context, const char *buffer_name,
1308 int dimension,
1309 int constrained_min, int constrained_extent,
1310 int required_min, int required_extent);
1311extern int halide_error_constraint_violated(void *user_context, const char *var, int val,
1312 const char *constrained_var, int constrained_val);
1313extern int halide_error_param_too_small_i64(void *user_context, const char *param_name,
1314 int64_t val, int64_t min_val);
1315extern int halide_error_param_too_small_u64(void *user_context, const char *param_name,
1316 uint64_t val, uint64_t min_val);
1317extern int halide_error_param_too_small_f64(void *user_context, const char *param_name,
1318 double val, double min_val);
1319extern int halide_error_param_too_large_i64(void *user_context, const char *param_name,
1320 int64_t val, int64_t max_val);
1321extern int halide_error_param_too_large_u64(void *user_context, const char *param_name,
1322 uint64_t val, uint64_t max_val);
1323extern int halide_error_param_too_large_f64(void *user_context, const char *param_name,
1324 double val, double max_val);
1325extern int halide_error_out_of_memory(void *user_context);
1326extern int halide_error_buffer_argument_is_null(void *user_context, const char *buffer_name);
1327extern int halide_error_debug_to_file_failed(void *user_context, const char *func,
1328 const char *filename, int error_code);
1329extern int halide_error_unaligned_host_ptr(void *user_context, const char *func_name, int alignment);
1330extern int halide_error_host_is_null(void *user_context, const char *func_name);
1331extern int halide_error_bad_fold(void *user_context, const char *func_name, const char *var_name,
1332 const char *loop_name);
1333extern int halide_error_bad_extern_fold(void *user_context, const char *func_name,
1334 int dim, int min, int extent, int valid_min, int fold_factor);
1335
1336extern int halide_error_fold_factor_too_small(void *user_context, const char *func_name, const char *var_name,
1337 int fold_factor, const char *loop_name, int required_extent);
1338extern int halide_error_requirement_failed(void *user_context, const char *condition, const char *message);
1339extern int halide_error_specialize_fail(void *user_context, const char *message);
1340extern int halide_error_no_device_interface(void *user_context);
1341extern int halide_error_device_interface_no_device(void *user_context);
1342extern int halide_error_host_and_device_dirty(void *user_context);
1343extern int halide_error_buffer_is_null(void *user_context, const char *routine);
1344extern int halide_error_device_dirty_with_no_device_support(void *user_context, const char *buffer_name);
1345// @}
1346
1347/** Optional features a compilation Target can have.
1348 * Be sure to keep this in sync with the Feature enum in Target.h and the implementation of
1349 * get_runtime_compatible_target in Target.cpp if you add a new feature.
1350 */
1351typedef enum halide_target_feature_t {
1352 halide_target_feature_jit = 0, ///< Generate code that will run immediately inside the calling process.
1353 halide_target_feature_debug, ///< Turn on debug info and output for runtime code.
1354 halide_target_feature_no_asserts, ///< Disable all runtime checks, for slightly tighter code.
1355 halide_target_feature_no_bounds_query, ///< Disable the bounds querying functionality.
1356
1357 halide_target_feature_sse41, ///< Use SSE 4.1 and earlier instructions. Only relevant on x86.
1358 halide_target_feature_avx, ///< Use AVX 1 instructions. Only relevant on x86.
1359 halide_target_feature_avx2, ///< Use AVX 2 instructions. Only relevant on x86.
1360 halide_target_feature_fma, ///< Enable x86 FMA instruction
1361 halide_target_feature_fma4, ///< Enable x86 (AMD) FMA4 instruction set
1362 halide_target_feature_f16c, ///< Enable x86 16-bit float support
1363
1364 halide_target_feature_armv7s, ///< Generate code for ARMv7s. Only relevant for 32-bit ARM.
1365 halide_target_feature_no_neon, ///< Avoid using NEON instructions. Only relevant for 32-bit ARM.
1366
1367 halide_target_feature_vsx, ///< Use VSX instructions. Only relevant on POWERPC.
1368 halide_target_feature_power_arch_2_07, ///< Use POWER ISA 2.07 new instructions. Only relevant on POWERPC.
1369
1370 halide_target_feature_cuda, ///< Enable the CUDA runtime. Defaults to compute capability 2.0 (Fermi)
1371 halide_target_feature_cuda_capability30, ///< Enable CUDA compute capability 3.0 (Kepler)
1372 halide_target_feature_cuda_capability32, ///< Enable CUDA compute capability 3.2 (Tegra K1)
1373 halide_target_feature_cuda_capability35, ///< Enable CUDA compute capability 3.5 (Kepler)
1374 halide_target_feature_cuda_capability50, ///< Enable CUDA compute capability 5.0 (Maxwell)
1375 halide_target_feature_cuda_capability61, ///< Enable CUDA compute capability 6.1 (Pascal)
1376 halide_target_feature_cuda_capability70, ///< Enable CUDA compute capability 7.0 (Volta)
1377 halide_target_feature_cuda_capability75, ///< Enable CUDA compute capability 7.5 (Turing)
1378 halide_target_feature_cuda_capability80, ///< Enable CUDA compute capability 8.0 (Ampere)
1379
1380 halide_target_feature_opencl, ///< Enable the OpenCL runtime.
1381 halide_target_feature_cl_doubles, ///< Enable double support on OpenCL targets
1382 halide_target_feature_cl_atomic64, ///< Enable 64-bit atomics operations on OpenCL targets
1383
1384 halide_target_feature_openglcompute, ///< Enable OpenGL Compute runtime.
1385
1386 halide_target_feature_user_context, ///< Generated code takes a user_context pointer as first argument
1387
1388 halide_target_feature_matlab, ///< Generate a mexFunction compatible with Matlab mex libraries. See tools/mex_halide.m.
1389
1390 halide_target_feature_profile, ///< Launch a sampling profiler alongside the Halide pipeline that monitors and reports the runtime used by each Func
1391 halide_target_feature_no_runtime, ///< Do not include a copy of the Halide runtime in any generated object file or assembly
1392
1393 halide_target_feature_metal, ///< Enable the (Apple) Metal runtime.
1394
1395 halide_target_feature_c_plus_plus_mangling, ///< Generate C++ mangled names for result function, et al
1396
1397 halide_target_feature_large_buffers, ///< Enable 64-bit buffer indexing to support buffers > 2GB. Ignored if bits != 64.
1398
1399 halide_target_feature_hvx_128, ///< Enable HVX 128 byte mode.
1400 halide_target_feature_hvx_v62, ///< Enable Hexagon v62 architecture.
1401 halide_target_feature_fuzz_float_stores, ///< On every floating point store, set the last bit of the mantissa to zero. Pipelines for which the output is very different with this feature enabled may also produce very different output on different processors.
1402 halide_target_feature_soft_float_abi, ///< Enable soft float ABI. This only enables the soft float ABI calling convention, which does not necessarily use soft floats.
1403 halide_target_feature_msan, ///< Enable hooks for MSAN support.
1404 halide_target_feature_avx512, ///< Enable the base AVX512 subset supported by all AVX512 architectures. The specific feature sets are AVX-512F and AVX512-CD. See https://en.wikipedia.org/wiki/AVX-512 for a description of each AVX subset.
1405 halide_target_feature_avx512_knl, ///< Enable the AVX512 features supported by Knight's Landing chips, such as the Xeon Phi x200. This includes the base AVX512 set, and also AVX512-CD and AVX512-ER.
1406 halide_target_feature_avx512_skylake, ///< Enable the AVX512 features supported by Skylake Xeon server processors. This adds AVX512-VL, AVX512-BW, and AVX512-DQ to the base set. The main difference from the base AVX512 set is better support for small integer ops. Note that this does not include the Knight's Landing features. Note also that these features are not available on Skylake desktop and mobile processors.
1407 halide_target_feature_avx512_cannonlake, ///< Enable the AVX512 features expected to be supported by future Cannonlake processors. This includes all of the Skylake features, plus AVX512-IFMA and AVX512-VBMI.
1408 halide_target_feature_avx512_sapphirerapids, ///< Enable the AVX512 features supported by Sapphire Rapids processors. This include all of the Cannonlake features, plus AVX512-VNNI and AVX512-BF16.
1409 halide_target_feature_hvx_use_shared_object, ///< Deprecated
1410 halide_target_feature_trace_loads, ///< Trace all loads done by the pipeline. Equivalent to calling Func::trace_loads on every non-inlined Func.
1411 halide_target_feature_trace_stores, ///< Trace all stores done by the pipeline. Equivalent to calling Func::trace_stores on every non-inlined Func.
1412 halide_target_feature_trace_realizations, ///< Trace all realizations done by the pipeline. Equivalent to calling Func::trace_realizations on every non-inlined Func.
1413 halide_target_feature_trace_pipeline, ///< Trace the pipeline.
1414 halide_target_feature_hvx_v65, ///< Enable Hexagon v65 architecture.
1415 halide_target_feature_hvx_v66, ///< Enable Hexagon v66 architecture.
1416 halide_target_feature_cl_half, ///< Enable half support on OpenCL targets
1417 halide_target_feature_strict_float, ///< Turn off all non-IEEE floating-point optimization. Currently applies only to LLVM targets.
1418 halide_target_feature_tsan, ///< Enable hooks for TSAN support.
1419 halide_target_feature_asan, ///< Enable hooks for ASAN support.
1420 halide_target_feature_d3d12compute, ///< Enable Direct3D 12 Compute runtime.
1421 halide_target_feature_check_unsafe_promises, ///< Insert assertions for promises.
1422 halide_target_feature_hexagon_dma, ///< Enable Hexagon DMA buffers.
1423 halide_target_feature_embed_bitcode, ///< Emulate clang -fembed-bitcode flag.
1424 halide_target_feature_enable_llvm_loop_opt, ///< Enable loop vectorization + unrolling in LLVM. Overrides halide_target_feature_disable_llvm_loop_opt. (Ignored for non-LLVM targets.)
1425 halide_target_feature_disable_llvm_loop_opt, ///< Disable loop vectorization + unrolling in LLVM. (Ignored for non-LLVM targets.)
1426 halide_target_feature_wasm_simd128, ///< Enable +simd128 instructions for WebAssembly codegen.
1427 halide_target_feature_wasm_signext, ///< Enable +sign-ext instructions for WebAssembly codegen.
1428 halide_target_feature_wasm_sat_float_to_int, ///< Enable saturating (nontrapping) float-to-int instructions for WebAssembly codegen.
1429 halide_target_feature_wasm_threads, ///< Enable use of threads in WebAssembly codegen. Requires the use of a wasm runtime that provides pthread-compatible wrappers (typically, Emscripten with the -pthreads flag). Unsupported under WASI.
1430 halide_target_feature_wasm_bulk_memory, ///< Enable +bulk-memory instructions for WebAssembly codegen.
1431 halide_target_feature_sve, ///< Enable ARM Scalable Vector Extensions
1432 halide_target_feature_sve2, ///< Enable ARM Scalable Vector Extensions v2
1433 halide_target_feature_egl, ///< Force use of EGL support.
1434 halide_target_feature_arm_dot_prod, ///< Enable ARMv8.2-a dotprod extension (i.e. udot and sdot instructions)
1435 halide_llvm_large_code_model, ///< Use the LLVM large code model to compile
1436 halide_target_feature_rvv, ///< Enable RISCV "V" Vector Extension
1437 halide_target_feature_armv81a, ///< Enable ARMv8.1-a instructions
1438 halide_target_feature_end ///< A sentinel. Every target is considered to have this feature, and setting this feature does nothing.
1439} halide_target_feature_t;
1440
1441/** This function is called internally by Halide in some situations to determine
1442 * if the current execution environment can support the given set of
1443 * halide_target_feature_t flags. The implementation must do the following:
1444 *
1445 * -- If there are flags set in features that the function knows *cannot* be supported, return 0.
1446 * -- Otherwise, return 1.
1447 * -- Note that any flags set in features that the function doesn't know how to test should be ignored;
1448 * this implies that a return value of 1 means "not known to be bad" rather than "known to be good".
1449 *
1450 * In other words: a return value of 0 means "It is not safe to use code compiled with these features",
1451 * while a return value of 1 means "It is not obviously unsafe to use code compiled with these features".
1452 *
1453 * The default implementation simply calls halide_default_can_use_target_features.
1454 *
1455 * Note that `features` points to an array of `count` uint64_t; this array must contain enough
1456 * bits to represent all the currently known features. Any excess bits must be set to zero.
1457 */
1458// @{
1459extern int halide_can_use_target_features(int count, const uint64_t *features);
1460typedef int (*halide_can_use_target_features_t)(int count, const uint64_t *features);
1461extern halide_can_use_target_features_t halide_set_custom_can_use_target_features(halide_can_use_target_features_t);
1462// @}
1463
1464/**
1465 * This is the default implementation of halide_can_use_target_features; it is provided
1466 * for convenience of user code that may wish to extend halide_can_use_target_features
1467 * but continue providing existing support, e.g.
1468 *
1469 * int halide_can_use_target_features(int count, const uint64_t *features) {
1470 * if (features[halide_target_somefeature >> 6] & (1LL << (halide_target_somefeature & 63))) {
1471 * if (!can_use_somefeature()) {
1472 * return 0;
1473 * }
1474 * }
1475 * return halide_default_can_use_target_features(count, features);
1476 * }
1477 */
1478extern int halide_default_can_use_target_features(int count, const uint64_t *features);
1479
1480typedef struct halide_dimension_t {
1481#if (__cplusplus >= 201103L || _MSVC_LANG >= 201103L)
1482 int32_t min = 0, extent = 0, stride = 0;
1483
1484 // Per-dimension flags. None are defined yet (This is reserved for future use).
1485 uint32_t flags = 0;
1486
1487 HALIDE_ALWAYS_INLINE halide_dimension_t() = default;
1488 HALIDE_ALWAYS_INLINE halide_dimension_t(int32_t m, int32_t e, int32_t s, uint32_t f = 0)
1489 : min(m), extent(e), stride(s), flags(f) {
1490 }
1491
1492 HALIDE_ALWAYS_INLINE bool operator==(const halide_dimension_t &other) const {
1493 return (min == other.min) &&
1494 (extent == other.extent) &&
1495 (stride == other.stride) &&
1496 (flags == other.flags);
1497 }
1498
1499 HALIDE_ALWAYS_INLINE bool operator!=(const halide_dimension_t &other) const {
1500 return !(*this == other);
1501 }
1502#else
1503 int32_t min, extent, stride;
1504
1505 // Per-dimension flags. None are defined yet (This is reserved for future use).
1506 uint32_t flags;
1507#endif
1508} halide_dimension_t;
1509
1510#ifdef __cplusplus
1511} // extern "C"
1512#endif
1513
/** Bit values used in the flags field of halide_buffer_t. */
typedef enum { halide_buffer_flag_host_dirty = 1,
               halide_buffer_flag_device_dirty = 2 } halide_buffer_flags;
1516
1517/**
1518 * The raw representation of an image passed around by generated
1519 * Halide code. It includes some stuff to track whether the image is
1520 * not actually in main memory, but instead on a device (like a
1521 * GPU). For a more convenient C++ wrapper, use Halide::Buffer<T>. */
1522typedef struct halide_buffer_t {
1523 /** A device-handle for e.g. GPU memory used to back this buffer. */
1524 uint64_t device;
1525
1526 /** The interface used to interpret the above handle. */
1527 const struct halide_device_interface_t *device_interface;
1528
1529 /** A pointer to the start of the data in main memory. In terms of
1530 * the Halide coordinate system, this is the address of the min
1531 * coordinates (defined below). */
1532 uint8_t *host;
1533
1534 /** flags with various meanings. */
1535 uint64_t flags;
1536
1537 /** The type of each buffer element. */
1538 struct halide_type_t type;
1539
1540 /** The dimensionality of the buffer. */
1541 int32_t dimensions;
1542
1543 /** The shape of the buffer. Halide does not own this array - you
1544 * must manage the memory for it yourself. */
1545 halide_dimension_t *dim;
1546
1547 /** Pads the buffer up to a multiple of 8 bytes */
1548 void *padding;
1549
1550#if (__cplusplus >= 201103L || _MSVC_LANG >= 201103L)
1551 /** Convenience methods for accessing the flags */
1552 // @{
1553 HALIDE_ALWAYS_INLINE bool get_flag(halide_buffer_flags flag) const {
1554 return (flags & flag) != 0;
1555 }
1556
1557 HALIDE_ALWAYS_INLINE void set_flag(halide_buffer_flags flag, bool value) {
1558 if (value) {
1559 flags |= flag;
1560 } else {
1561 flags &= ~flag;
1562 }
1563 }
1564
1565 HALIDE_ALWAYS_INLINE bool host_dirty() const {
1566 return get_flag(halide_buffer_flag_host_dirty);
1567 }
1568
1569 HALIDE_ALWAYS_INLINE bool device_dirty() const {
1570 return get_flag(halide_buffer_flag_device_dirty);
1571 }
1572
1573 HALIDE_ALWAYS_INLINE void set_host_dirty(bool v = true) {
1574 set_flag(halide_buffer_flag_host_dirty, v);
1575 }
1576
1577 HALIDE_ALWAYS_INLINE void set_device_dirty(bool v = true) {
1578 set_flag(halide_buffer_flag_device_dirty, v);
1579 }
1580 // @}
1581
1582 /** The total number of elements this buffer represents. Equal to
1583 * the product of the extents */
1584 HALIDE_ALWAYS_INLINE size_t number_of_elements() const {
1585 size_t s = 1;
1586 for (int i = 0; i < dimensions; i++) {
1587 s *= dim[i].extent;
1588 }
1589 return s;
1590 }
1591
1592 /** Offset to the element with the lowest address.
1593 * If all strides are positive, equal to zero.
1594 * Offset is in elements, not bytes.
1595 * Unlike begin(), this is ok to call on an unallocated buffer. */
1596 HALIDE_ALWAYS_INLINE ptrdiff_t begin_offset() const {
1597 ptrdiff_t index = 0;
1598 for (int i = 0; i < dimensions; i++) {
1599 const int stride = dim[i].stride;
1600 if (stride < 0) {
1601 index += stride * (ptrdiff_t)(dim[i].extent - 1);
1602 }
1603 }
1604 return index;
1605 }
1606
1607 /** An offset to one beyond the element with the highest address.
1608 * Offset is in elements, not bytes.
1609 * Unlike end(), this is ok to call on an unallocated buffer. */
1610 HALIDE_ALWAYS_INLINE ptrdiff_t end_offset() const {
1611 ptrdiff_t index = 0;
1612 for (int i = 0; i < dimensions; i++) {
1613 const int stride = dim[i].stride;
1614 if (stride > 0) {
1615 index += stride * (ptrdiff_t)(dim[i].extent - 1);
1616 }
1617 }
1618 index += 1;
1619 return index;
1620 }
1621
1622 /** A pointer to the element with the lowest address.
1623 * If all strides are positive, equal to the host pointer.
1624 * Illegal to call on an unallocated buffer. */
1625 HALIDE_ALWAYS_INLINE uint8_t *begin() const {
1626 return host + begin_offset() * type.bytes();
1627 }
1628
1629 /** A pointer to one beyond the element with the highest address.
1630 * Illegal to call on an unallocated buffer. */
1631 HALIDE_ALWAYS_INLINE uint8_t *end() const {
1632 return host + end_offset() * type.bytes();
1633 }
1634
1635 /** The total number of bytes spanned by the data in memory. */
1636 HALIDE_ALWAYS_INLINE size_t size_in_bytes() const {
1637 return (size_t)(end_offset() - begin_offset()) * type.bytes();
1638 }
1639
1640 /** A pointer to the element at the given location. */
1641 HALIDE_ALWAYS_INLINE uint8_t *address_of(const int *pos) const {
1642 ptrdiff_t index = 0;
1643 for (int i = 0; i < dimensions; i++) {
1644 index += (ptrdiff_t)dim[i].stride * (pos[i] - dim[i].min);
1645 }
1646 return host + index * type.bytes();
1647 }
1648
1649 /** Attempt to call device_sync for the buffer. If the buffer
1650 * has no device_interface (or no device_sync), this is a quiet no-op.
1651 * Calling this explicitly should rarely be necessary, except for profiling. */
1652 HALIDE_ALWAYS_INLINE int device_sync(void *ctx = nullptr) {
1653 if (device_interface && device_interface->device_sync) {
1654 return device_interface->device_sync(ctx, this);
1655 }
1656 return 0;
1657 }
1658
1659 /** Check if an input buffer passed extern stage is a querying
1660 * bounds. Compared to doing the host pointer check directly,
1661 * this both adds clarity to code and will facilitate moving to
1662 * another representation for bounds query arguments. */
1663 HALIDE_ALWAYS_INLINE bool is_bounds_query() const {
1664 return host == nullptr && device == 0;
1665 }
1666
1667#endif
1668} halide_buffer_t;
1669
1670#ifdef __cplusplus
1671extern "C" {
1672#endif
1673
#ifndef HALIDE_ATTRIBUTE_DEPRECATED
#ifdef HALIDE_ALLOW_DEPRECATED
// With HALIDE_ALLOW_DEPRECATED defined, deprecation markers expand to nothing.
#define HALIDE_ATTRIBUTE_DEPRECATED(x)
#else
#ifdef _MSC_VER
// MSVC spelling of the deprecation attribute.
#define HALIDE_ATTRIBUTE_DEPRECATED(x) __declspec(deprecated(x))
#else
// GCC/Clang spelling of the deprecation attribute.
#define HALIDE_ATTRIBUTE_DEPRECATED(x) __attribute__((deprecated(x)))
#endif
#endif
#endif
1685
/** halide_scalar_value_t is a simple union able to represent all the well-known
 * scalar values in a filter argument. Note that it isn't tagged with a type;
 * you must ensure you know the proper type before accessing. Most user
 * code will never need to create instances of this struct; its primary use
 * is to hold def/min/max values in a halide_filter_argument_t. (Note that
 * this is conceptually just a union; it's wrapped in a struct to ensure
 * that it doesn't get anonymized by LLVM.)
 */
struct halide_scalar_value_t {
    union {
        bool b;
        int8_t i8;
        int16_t i16;
        int32_t i32;
        int64_t i64;
        uint8_t u8;
        uint16_t u16;
        uint32_t u32;
        uint64_t u64;
        float f32;
        double f64;
        void *handle;
    } u;
#ifdef __cplusplus
    // Default-construct to zero; writing the widest member clears the
    // storage shared by all the smaller members.
    HALIDE_ALWAYS_INLINE halide_scalar_value_t() {
        u.u64 = 0;
    }
#endif
};
1715
/** The kind of a filter argument described by halide_filter_argument_t. */
enum halide_argument_kind_t {
    halide_argument_kind_input_scalar = 0,   ///< scalar input
    halide_argument_kind_input_buffer = 1,   ///< buffer input
    halide_argument_kind_output_buffer = 2   ///< buffer output
};
1721
1722/*
1723 These structs must be robust across different compilers and settings; when
1724 modifying them, strive for the following rules:
1725
1726 1) All fields are explicitly sized. I.e. must use int32_t and not "int"
1727 2) All fields must land on an alignment boundary that is the same as their size
1728 3) Explicit padding is added to make that so
1729 4) The sizeof the struct is padded out to a multiple of the largest natural size thing in the struct
1730 5) don't forget that 32 and 64 bit pointers are different sizes
1731*/
1732
1733/**
1734 * Obsolete version of halide_filter_argument_t; only present in
1735 * code that wrote halide_filter_metadata_t version 0.
1736 */
struct halide_filter_argument_t_v0 {
    const char *name;    // name of the argument
    int32_t kind;        // actually halide_argument_kind_t
    int32_t dimensions;  // zero for scalar arguments
    struct halide_type_t type;
    // def/min/max for scalar arguments; presumably may be null, as in the
    // current halide_filter_argument_t.
    const struct halide_scalar_value_t *def, *min, *max;
};
1744
1745/**
1746 * halide_filter_argument_t is essentially a plain-C-struct equivalent to
1747 * Halide::Argument; most user code will never need to create one.
1748 */
struct halide_filter_argument_t {
    const char *name;    // name of the argument; will never be null or empty.
    int32_t kind;        // actually halide_argument_kind_t
    int32_t dimensions;  // always zero for scalar arguments
    struct halide_type_t type;
    // These pointers should always be null for buffer arguments,
    // and *may* be null for scalar arguments. (A null value means
    // there is no def/min/max/estimate specified for this argument.)
    const struct halide_scalar_value_t *scalar_def, *scalar_min, *scalar_max, *scalar_estimate;
    // This pointer should always be null for scalar arguments,
    // and *may* be null for buffer arguments. If not null, it should always
    // point to an array of dimensions*2 pointers, which will be the (min, extent)
    // estimates for each dimension of the buffer. (Note that any of the pointers
    // may be null as well.)
    int64_t const *const *buffer_estimates;
};
1765
struct halide_filter_metadata_t {
#ifdef __cplusplus
    static const int32_t VERSION = 1;
#endif

    /** version of this metadata; currently always 1. */
    int32_t version;

    /** The number of entries in the arguments field. This is always >= 1. */
    int32_t num_arguments;

    /** An array of the filter's input and output arguments; this will never be
     * null. The order of arguments is not guaranteed (input and output arguments
     * may come in any order); however, it is guaranteed that all arguments
     * will have a unique name within a given filter. */
    const struct halide_filter_argument_t *arguments;

    /** The Target for which the filter was compiled. This is always
     * a canonical Target string (ie a product of Target::to_string). */
    const char *target;

    /** The function name of the filter. */
    const char *name;
};
1790
/** halide_register_argv_and_metadata() is a **user-defined** function that
 * must be provided in order to use the registration.cc files produced
 * by Generators when the 'registration' output is requested. Each registration.cc
 * file provides a static initializer that calls this function with the given
 * filter's argv-call variant, its metadata, and (optionally) some additional
 * textual data that the build system chooses to tack on for its own purposes.
 * Note that this will be called at static-initializer time (i.e., before
 * main() is called), and in an unpredictable order. Note that extra_key_value_pairs
 * may be nullptr; if it's not null, it's expected to be a null-terminated list
 * of strings, with an even number of entries. */
void halide_register_argv_and_metadata(
    int (*filter_argv_call)(void **),
    const struct halide_filter_metadata_t *filter_metadata,
    const char *const *extra_key_value_pairs);
1805
1806/** The functions below here are relevant for pipelines compiled with
1807 * the -profile target flag, which runs a sampling profiler thread
1808 * alongside the pipeline. */
1809
1810/** Per-Func state tracked by the sampling profiler. */
struct halide_profiler_func_stats {
    /** Total time taken evaluating this Func (in nanoseconds). */
    uint64_t time;

    /** The current memory allocation of this Func. */
    uint64_t memory_current;

    /** The peak memory allocation of this Func. */
    uint64_t memory_peak;

    /** The total memory allocation of this Func. */
    uint64_t memory_total;

    /** The peak stack allocation of this Func's threads. */
    uint64_t stack_peak;

    /** The average number of thread pool worker threads active while computing
     * this Func, expressed as the ratio numerator/denominator. */
    uint64_t active_threads_numerator, active_threads_denominator;

    /** The name of this Func. A global constant string. */
    const char *name;

    /** The total number of memory allocations made by this Func. */
    int num_allocs;
};
1836
1837/** Per-pipeline state tracked by the sampling profiler. These exist
1838 * in a linked list. */
struct halide_profiler_pipeline_stats {
    /** Total time spent inside this pipeline (in nanoseconds) */
    uint64_t time;

    /** The current memory allocation of funcs in this pipeline. */
    uint64_t memory_current;

    /** The peak memory allocation of funcs in this pipeline. */
    uint64_t memory_peak;

    /** The total memory allocation of funcs in this pipeline. */
    uint64_t memory_total;

    /** The average number of thread pool worker threads doing useful
     * work while computing this pipeline, expressed as the ratio
     * numerator/denominator. */
    uint64_t active_threads_numerator, active_threads_denominator;

    /** The name of this pipeline. A global constant string. */
    const char *name;

    /** An array containing states for each Func in this pipeline. */
    struct halide_profiler_func_stats *funcs;

    /** The next pipeline_stats pointer. It's a void * because types
     * in the Halide runtime may not currently be recursive. */
    void *next;

    /** The number of funcs in this pipeline. */
    int num_funcs;

    /** An internal base id used to identify the funcs in this pipeline. */
    int first_func_id;

    /** The number of times this pipeline has been run. */
    int runs;

    /** The total number of samples taken inside of this pipeline. */
    int samples;

    /** The total number of memory allocations made by funcs in this pipeline. */
    int num_allocs;
};
1881
1882/** The global state of the profiler. */
1883
struct halide_profiler_state {
    /** Guards access to the fields below. If not locked, the sampling
     * profiler thread is free to modify things below (including
     * reordering the linked list of pipeline stats). */
    struct halide_mutex lock;

    /** The amount of time the profiler thread sleeps between samples
     * in milliseconds. Defaults to 1. */
    int sleep_time;

    /** An internal id used for bookkeeping. */
    int first_free_id;

    /** The id of the current running Func. Set by the pipeline, read
     * periodically by the profiler thread. */
    int current_func;

    /** The number of threads currently doing work. */
    int active_threads;

    /** A linked list of stats gathered for each pipeline. */
    struct halide_profiler_pipeline_stats *pipelines;

    /** Retrieve remote profiler state. Used so that the sampling
     * profiler can follow along with execution that occurs elsewhere,
     * e.g. on a DSP. If null, it reads from the int above instead. */
    void (*get_remote_profiler_state)(int *func, int *active_workers);

    /** Sampling thread reference to be joined at shutdown. */
    struct halide_thread *sampling_thread;
};
1915
1916/** Profiler func ids with special meanings. */
/** Profiler func ids with special meanings. */
enum {
    /// The value current_func takes on when execution is not inside Halide code.
    halide_profiler_outside_of_halide = -1,
    /// Set current_func to this value to tell the profiling thread to
    /// halt. It will start up again next time you run a pipeline with
    /// profiling enabled.
    halide_profiler_please_stop = -2
};
1925
/** Get a pointer to the global profiler state for programmatic
 * inspection. Lock it before using to pause the profiler. */
extern struct halide_profiler_state *halide_profiler_get_state();

/** Get a pointer to the pipeline state associated with pipeline_name.
 * This function grabs the global profiler state's lock on entry. */
extern struct halide_profiler_pipeline_stats *halide_profiler_get_pipeline_state(const char *pipeline_name);

/** Reset profiler state cheaply. May leave threads running or some
 * memory allocated but all accumulated statistics are reset.
 * WARNING: Do NOT call this method while any halide pipeline is
 * running; halide_profiler_memory_allocate/free and
 * halide_profiler_stack_peak_update update the profiler pipeline's
 * state without grabbing the global profiler state's lock. */
extern void halide_profiler_reset();

/** Reset all profiler state.
 * WARNING: Do NOT call this method while any halide pipeline is
 * running; halide_profiler_memory_allocate/free and
 * halide_profiler_stack_peak_update update the profiler pipeline's
 * state without grabbing the global profiler state's lock. */
void halide_profiler_shutdown();

/** Print out timing statistics for everything run since the last
 * reset. Also happens at process exit. */
extern void halide_profiler_report(void *user_context);
1952
/// \name "Float16" functions
/// These functions operate on bits (``uint16_t``) representing a half
/// precision floating point number (IEEE-754 2008 binary16).
//@{

/** Read bits representing a half precision floating point number and return
 * the float that represents the same value */
extern float halide_float16_bits_to_float(uint16_t);

/** Read bits representing a half precision floating point number and return
 * the double that represents the same value */
extern double halide_float16_bits_to_double(uint16_t);

// TODO: Conversion functions to half

//@}
1969
1970// Allocating and freeing device memory is often very slow. The
1971// methods below give Halide's runtime permission to hold onto device
1972// memory to service future requests instead of returning it to the
1973// underlying device API. The API does not manage an allocation pool,
1974// all it does is provide access to a shared counter that acts as a
1975// limit on the unused memory not yet returned to the underlying
1976// device API. It makes callbacks to participants when memory needs to
1977// be released because the limit is about to be exceeded (either
1978// because the limit has been reduced, or because the memory owned by
1979// some participant becomes unused).
1980
/** Tell Halide whether or not it is permitted to hold onto device
 * allocations to service future requests instead of returning them
 * eagerly to the underlying device API. Many device allocators are
 * quite slow, so it can be beneficial to set this to true. The
 * default value for now is false.
 *
 * Note that if enabled, the eviction policy is very simplistic. The
 * 32 most-recently used allocations are preserved, regardless of
 * their size. Additionally, if a call to cuMalloc results in an
 * out-of-memory error, the entire cache is flushed and the allocation
 * is retried. See https://github.com/halide/Halide/issues/4093
 *
 * If set to false, releases all unused device allocations back to the
 * underlying device APIs. For finer-grained control, see specific
 * methods in each device api runtime. */
extern int halide_reuse_device_allocations(void *user_context, bool);

/** Determines whether on device_free the memory is returned
 * immediately to the device API, or placed on a free list for future
 * use. Override and switch based on the user_context for
 * finer-grained control. By default just returns the value most
 * recently set by the method above. */
extern bool halide_can_reuse_device_allocations(void *user_context);

struct halide_device_allocation_pool {
    // Callback invoked when this pool must release all unused
    // allocations it holds; returns an int status code (presumably a
    // halide error code -- confirm against the runtime implementations).
    int (*release_unused)(void *user_context);
    // Next pool in the global registration list; clobbered by
    // halide_register_device_allocation_pool.
    struct halide_device_allocation_pool *next;
};

/** Register a callback to be informed when
 * halide_reuse_device_allocations(false) is called, and all unused
 * device allocations must be released. The object passed should have
 * global lifetime, and its next field will be clobbered. */
extern void halide_register_device_allocation_pool(struct halide_device_allocation_pool *);
2015
2016#ifdef __cplusplus
2017} // End extern "C"
2018#endif
2019
2020#if (__cplusplus >= 201103L || _MSVC_LANG >= 201103L)
2021
namespace {
// Primary template is declared but never defined: instantiating it for a
// non-pointer T produces a compile-time "incomplete type" error.
template<typename T>
struct check_is_pointer;
// Defined (empty) only for pointer types, so only those compile.
template<typename T>
struct check_is_pointer<T *> {};
} // namespace
2028
/** Construct the halide equivalent of a C type */
template<typename T>
HALIDE_ALWAYS_INLINE halide_type_t halide_type_of() {
    // Create a compile-time error if T is not a pointer (without
    // using any includes - this code goes into the runtime).
    check_is_pointer<T> check;
    (void)check;
    // Any pointer type is represented as a 64-bit handle.
    return halide_type_t(halide_type_handle, 64);
}
2038
// Specializations of halide_type_of for each built-in scalar type.
// Note that bool maps to a 1-bit unsigned integer type.
template<>
HALIDE_ALWAYS_INLINE halide_type_t halide_type_of<float>() {
    return halide_type_t(halide_type_float, 32);
}

template<>
HALIDE_ALWAYS_INLINE halide_type_t halide_type_of<double>() {
    return halide_type_t(halide_type_float, 64);
}

template<>
HALIDE_ALWAYS_INLINE halide_type_t halide_type_of<bool>() {
    return halide_type_t(halide_type_uint, 1);
}

template<>
HALIDE_ALWAYS_INLINE halide_type_t halide_type_of<uint8_t>() {
    return halide_type_t(halide_type_uint, 8);
}

template<>
HALIDE_ALWAYS_INLINE halide_type_t halide_type_of<uint16_t>() {
    return halide_type_t(halide_type_uint, 16);
}

template<>
HALIDE_ALWAYS_INLINE halide_type_t halide_type_of<uint32_t>() {
    return halide_type_t(halide_type_uint, 32);
}

template<>
HALIDE_ALWAYS_INLINE halide_type_t halide_type_of<uint64_t>() {
    return halide_type_t(halide_type_uint, 64);
}

template<>
HALIDE_ALWAYS_INLINE halide_type_t halide_type_of<int8_t>() {
    return halide_type_t(halide_type_int, 8);
}

template<>
HALIDE_ALWAYS_INLINE halide_type_t halide_type_of<int16_t>() {
    return halide_type_t(halide_type_int, 16);
}

template<>
HALIDE_ALWAYS_INLINE halide_type_t halide_type_of<int32_t>() {
    return halide_type_t(halide_type_int, 32);
}

template<>
HALIDE_ALWAYS_INLINE halide_type_t halide_type_of<int64_t>() {
    return halide_type_t(halide_type_int, 64);
}
2093
2094#endif // (__cplusplus >= 201103L || _MSVC_LANG >= 201103L)
2095
2096#endif // HALIDE_HALIDERUNTIME_H
2097
2098namespace Halide {
2099namespace Internal {
2100
/** A class representing a reference count to be used with IntrusivePtr */
class RefCount {
    std::atomic<int> count;

public:
    RefCount() noexcept
        : count(0) {
    }

    /** Atomically add one to the count and return the updated value. */
    int increment() {
        return count.fetch_add(1) + 1;
    }

    /** Atomically subtract one from the count and return the updated value. */
    int decrement() {
        return count.fetch_sub(1) - 1;
    }

    /** True iff no references are currently held. */
    bool is_const_zero() const {
        return count.load() == 0;
    }
};
2119
2120/**
2121 * Because in this header we don't yet know how client classes store
2122 * their RefCount (and we don't want to depend on the declarations of
2123 * the client classes), any class that you want to hold onto via one
2124 * of these must provide implementations of ref_count and destroy,
2125 * which we forward-declare here.
2126 *
2127 * E.g. if you want to use IntrusivePtr<MyClass>, then you should
2128 * define something like this in MyClass.cpp (assuming MyClass has
2129 * a field: mutable RefCount ref_count):
2130 *
2131 * template<> RefCount &ref_count<MyClass>(const MyClass *c) noexcept {return c->ref_count;}
2132 * template<> void destroy<MyClass>(const MyClass *c) {delete c;}
2133 */
2134// @{
2135template<typename T>
2136RefCount &ref_count(const T *t) noexcept;
2137template<typename T>
2138void destroy(const T *t);
2139// @}
2140
2141/** Intrusive shared pointers have a reference count (a
2142 * RefCount object) stored in the class itself. This is perhaps more
2143 * efficient than storing it externally, but more importantly, it
2144 * means it's possible to recover a reference-counted handle from the
2145 * raw pointer, and it's impossible to have two different reference
2146 * counts attached to the same raw object. Seeing as we pass around
2147 * raw pointers to concrete IRNodes and Expr's interchangeably, this
2148 * is a useful property.
2149 */
template<typename T>
struct IntrusivePtr {
private:
    // Bump the pointee's ref count. No-op for null.
    void incref(T *p) {
        if (p) {
            ref_count(p).increment();
        }
    }

    // Drop one reference, destroying the pointee when the count reaches
    // exactly zero. No-op for null.
    void decref(T *p) {
        if (p) {
            // Note that if the refcount is already zero, then we're
            // in a recursive destructor due to a self-reference (a
            // cycle), where the ref_count has been adjusted to remove
            // the counts due to the cycle. The next line then makes
            // the ref_count negative, which prevents actually
            // entering the destructor recursively.
            if (ref_count(p).decrement() == 0) {
                destroy(p);
            }
        }
    }

protected:
    T *ptr = nullptr;

public:
    /** Access the raw pointer in a variety of ways.
     * Note that a "const IntrusivePtr<T>" is not the same thing as an
     * IntrusivePtr<const T>. So the methods that return the ptr are
     * const, despite not adding an extra const to T. */
    // @{
    T *get() const {
        return ptr;
    }

    T &operator*() const {
        return *ptr;
    }

    T *operator->() const {
        return ptr;
    }
    // @}

    /** Releases this handle's reference (see decref above). */
    ~IntrusivePtr() {
        decref(ptr);
    }

    HALIDE_ALWAYS_INLINE
    IntrusivePtr() = default;

    /** Wrap a raw pointer, taking a new reference to it. */
    HALIDE_ALWAYS_INLINE
    IntrusivePtr(T *p)
        : ptr(p) {
        incref(ptr);
    }

    /** Copy-construct: share the pointee and take a new reference. */
    HALIDE_ALWAYS_INLINE
    IntrusivePtr(const IntrusivePtr<T> &other) noexcept
        : ptr(other.ptr) {
        incref(ptr);
    }

    /** Move-construct: steal other's reference; no count adjustment needed. */
    HALIDE_ALWAYS_INLINE
    IntrusivePtr(IntrusivePtr<T> &&other) noexcept
        : ptr(other.ptr) {
        other.ptr = nullptr;
    }

    // NOLINTNEXTLINE(bugprone-unhandled-self-assignment)
    IntrusivePtr<T> &operator=(const IntrusivePtr<T> &other) {
        // Same-ptr but different-this happens frequently enough
        // to check for (see https://github.com/halide/Halide/pull/5412)
        if (other.ptr == ptr) {
            return *this;
        }
        // Other can be inside of something owned by this, so we
        // should be careful to incref other before we decref
        // ourselves.
        T *temp = other.ptr;
        incref(temp);
        decref(ptr);
        ptr = temp;
        return *this;
    }

    /** Move-assign by swapping: the old pointee's reference is released
     * when other (which now holds it) is destroyed. */
    IntrusivePtr<T> &operator=(IntrusivePtr<T> &&other) noexcept {
        std::swap(ptr, other.ptr);
        return *this;
    }

    /* Handles can be null. This checks that. */
    HALIDE_ALWAYS_INLINE
    bool defined() const {
        return ptr != nullptr;
    }

    /* Check if two handles point to the same ptr. This is
     * equality of reference, not equality of value. */
    HALIDE_ALWAYS_INLINE
    bool same_as(const IntrusivePtr &other) const {
        return ptr == other.ptr;
    }

    /** Arbitrary but consistent ordering by pointer value (e.g. so these
     * can be used as keys in ordered containers). */
    HALIDE_ALWAYS_INLINE
    bool operator<(const IntrusivePtr<T> &other) const {
        return ptr < other.ptr;
    }
};
2260
2261} // namespace Internal
2262} // namespace Halide
2263
2264#endif
2265#ifndef HALIDE_TYPE_H
2266#define HALIDE_TYPE_H
2267
2268#ifndef HALIDE_ERROR_H
2269#define HALIDE_ERROR_H
2270
2271#include <sstream>
2272#include <stdexcept>
2273
2274#ifndef HALIDE_DEBUG_H
2275#define HALIDE_DEBUG_H
2276
2277/** \file
2278 * Defines functions for debug logging during code generation.
2279 */
2280
2281#include <cstdlib>
2282#include <iostream>
2283#include <string>
2284
2285namespace Halide {
2286
2287struct Expr;
2288struct Type;
2289// Forward declare some things from IRPrinter, which we can't include yet.
2290std::ostream &operator<<(std::ostream &stream, const Expr &);
2291std::ostream &operator<<(std::ostream &stream, const Type &);
2292
2293class Module;
2294std::ostream &operator<<(std::ostream &stream, const Module &);
2295
2296struct Target;
2297/** Emit a halide Target in a human readable form */
2298std::ostream &operator<<(std::ostream &stream, const Target &);
2299
2300namespace Internal {
2301
2302struct Stmt;
2303std::ostream &operator<<(std::ostream &stream, const Stmt &);
2304
2305struct LoweredFunc;
2306std::ostream &operator<<(std::ostream &, const LoweredFunc &);
2307
2308/** For optional debugging during codegen, use the debug class as
2309 * follows:
2310 *
2311 \code
2312 debug(verbosity) << "The expression is " << expr << "\n";
2313 \endcode
2314 *
2315 * verbosity of 0 always prints, 1 should print after every major
2316 * stage, 2 should be used for more detail, and 3 should be used for
2317 * tracing everything that occurs. The verbosity with which to print
2318 * is determined by the value of the environment variable
2319 * HL_DEBUG_CODEGEN
2320 */
2321
class debug {
    // Fixed at construction: whether this instance emits anything at all.
    const bool logging;

public:
    /** A debug(v) stream emits output iff v does not exceed the current
     * debug level (see debug_level below). */
    debug(int verbosity)
        : logging(verbosity <= debug_level()) {
    }

    /** Stream a value to stderr when logging is enabled; a no-op
     * otherwise. Returns *this so calls can be chained. */
    template<typename T>
    debug &operator<<(T &&x) {
        if (!logging) {
            return *this;
        }
        std::cerr << std::forward<T>(x);
        return *this;
    }

    /** The current verbosity threshold (per the HL_DEBUG_CODEGEN
     * environment variable; see the comment above). */
    static int debug_level();
};
2340
2341} // namespace Internal
2342} // namespace Halide
2343
2344#endif
2345
2346namespace Halide {
2347
2348/** Query whether Halide was compiled with exceptions. */
2349bool exceptions_enabled();
2350
2351/** A base class for Halide errors. */
struct Error : public std::runtime_error {
    // Give each class a non-inlined constructor so that the type
    // doesn't get separately instantiated in each compilation unit.
    Error(const std::string &msg);
};

/** An error that occurs while running a JIT-compiled Halide pipeline. */
struct RuntimeError : public Error {
    RuntimeError(const std::string &msg);
};

/** An error that occurs while compiling a Halide pipeline that Halide
 * attributes to a user error. */
struct CompileError : public Error {
    CompileError(const std::string &msg);
};

/** An error that occurs while compiling a Halide pipeline that Halide
 * attributes to an internal compiler bug, or to an invalid use of
 * Halide's internals. */
struct InternalError : public Error {
    InternalError(const std::string &msg);
};
2375
/** CompileTimeErrorReporter is used at compile time (*not* runtime) when
 * an error or warning is generated by Halide. Note that error() is called
 * after a fatal error has occurred, and returning to Halide may cause a crash;
 * implementations of CompileTimeErrorReporter::error() should never return.
 * (Implementations of CompileTimeErrorReporter::warning() may return but
 * may also abort(), exit(), etc.)
 */
class CompileTimeErrorReporter {
public:
    virtual ~CompileTimeErrorReporter() = default;
    /** Called for warnings; implementations may return normally. */
    virtual void warning(const char *msg) = 0;
    /** Called after a fatal error; implementations should never return
     * (see the note above this class). */
    virtual void error(const char *msg) = 0;
};
2389
2390/** The default error reporter logs to stderr, then throws an exception
2391 * (if HALIDE_WITH_EXCEPTIONS) or calls abort (if not). This allows customization
2392 * of that behavior if a more gentle response to error reporting is desired.
2393 * Note that error_reporter is expected to remain valid across all Halide usage;
2394 * it is up to the caller to ensure that this is the case (and to do any
2395 * cleanup necessary).
2396 */
2397void set_custom_compile_time_error_reporter(CompileTimeErrorReporter *error_reporter);
2398
2399namespace Internal {
2400
struct ErrorReport {
    // Bit flags describing the kind of report. Combined by the error
    // macros below (internal_error, user_error, user_warning, ...).
    enum {
        User = 0x0001,     // attributed to user code rather than a Halide bug
        Warning = 0x0002,  // a warning rather than a fatal error
        Runtime = 0x0004   // raised from a JIT-compiled pipeline
    };

    // Accumulates the message text via operator<< below.
    std::ostringstream msg;
    const int flags;

    // Arguments: file, line, condition string (may be null), flags.
    // The error macros pass __FILE__, __LINE__, and #condition.
    ErrorReport(const char *f, int l, const char *cs, int flags);

    // Just a trick used to convert RValue into LValue
    HALIDE_ALWAYS_INLINE ErrorReport &ref() {
        return *this;
    }

    /** Append a value to the message being built, returning *this so
     * appends can be chained. */
    template<typename T>
    ErrorReport &operator<<(const T &x) {
        msg << x;
        return *this;
    }

    /** When you're done using << on the object, and let it fall out of
     * scope, this errors out, or throws an exception if they are
     * enabled. This is a little dangerous because the destructor will
     * also be called if there's an exception in flight due to an
     * error in one of the arguments passed to operator<<. We handle
     * this by only actually throwing if there isn't an exception in
     * flight already.
     */
#if __cplusplus >= 201100 || _MSC_VER >= 1900
    ~ErrorReport() noexcept(false);
#else
    ~ErrorReport();
#endif
};
2438
2439// This uses operator precedence as a trick to avoid argument evaluation if
2440// an assertion is true: it is intended to be used as part of the
2441// _halide_internal_assertion macro, to coerce the result of the stream
2442// expression to void (to match the condition-is-false case).
class Voidifier {
public:
    HALIDE_ALWAYS_INLINE Voidifier() = default;
    // This has to be an operator with a precedence lower than << but
    // higher than ?:
    /** Discards the fully-built ErrorReport and yields void, so both
     * arms of the ?: in _halide_internal_assertion have type void. */
    HALIDE_ALWAYS_INLINE void operator&(ErrorReport &) {
    }
};
2451
2452/**
2453 * _halide_internal_assertion is used to implement our assertion macros
2454 * in such a way that the messages output for the assertion are only
2455 * evaluated if the assertion's value is false.
2456 *
2457 * Note that this macro intentionally has no parens internally; in actual
2458 * use, the implicit grouping will end up being
2459 *
2460 * condition ? (void) : (Voidifier() & (ErrorReport << arg1 << arg2 ... << argN))
2461 *
2462 * This (regrettably) requires a macro to work, but has the highly desirable
2463 * effect that all assertion parameters are totally skipped (not ever evaluated)
2464 * when the assertion is true.
2465 */
2466#define _halide_internal_assertion(condition, flags) \
2467 /* NOLINTNEXTLINE(bugprone-macro-parentheses) */ \
2468 (condition) ? (void)0 : ::Halide::Internal::Voidifier() & ::Halide::Internal::ErrorReport(__FILE__, __LINE__, #condition, flags).ref()
2469
2470#define internal_error Halide::Internal::ErrorReport(__FILE__, __LINE__, nullptr, 0)
2471#define user_error Halide::Internal::ErrorReport(__FILE__, __LINE__, nullptr, Halide::Internal::ErrorReport::User)
2472#define user_warning Halide::Internal::ErrorReport(__FILE__, __LINE__, nullptr, Halide::Internal::ErrorReport::User | Halide::Internal::ErrorReport::Warning)
2473#define halide_runtime_error Halide::Internal::ErrorReport(__FILE__, __LINE__, nullptr, Halide::Internal::ErrorReport::User | Halide::Internal::ErrorReport::Runtime)
2474
2475#define internal_assert(c) _halide_internal_assertion(c, 0)
2476#define user_assert(c) _halide_internal_assertion(c, Halide::Internal::ErrorReport::User)
2477
2478// The nicely named versions get cleaned up at the end of Halide.h,
2479// but user code might want to do halide-style user_asserts (e.g. the
2480// Extern macros introduce calls to user_assert), so for that purpose
2481// we define an equivalent macro that can be used outside of Halide.h
2482#define _halide_user_assert(c) _halide_internal_assertion(c, Halide::Internal::ErrorReport::User)
2483
2484// N.B. Any function that might throw a user_assert or user_error may
2485// not be inlined into the user's code, or the line number will be
2486// misattributed to Halide.h. Either make such functions internal to
2487// libHalide, or mark them as HALIDE_NO_USER_CODE_INLINE.
2488
2489} // namespace Internal
2490
2491} // namespace Halide
2492
2493#endif
2494#ifndef HALIDE_FLOAT16_H
2495#define HALIDE_FLOAT16_H
2496
2497#include <cstdint>
2498#include <string>
2499
2500namespace Halide {
2501
/** Class that provides a type that implements half precision
 * floating point (IEEE754 2008 binary16) in software.
 *
 * This type is enforced to be 16 bits wide and maintains no state
 * other than the raw IEEE754 binary16 bits, so that it can be passed
 * to code that checks a type's size and used for halide_buffer_t allocation.
 * */
struct float16_t {

    // Layout constants for the IEEE754 binary16 format:
    // 1 sign bit (bit 15), 5 exponent bits (bits 10-14), 10 mantissa bits.
    static const int mantissa_bits = 10;
    static const uint16_t sign_mask = 0x8000;
    static const uint16_t exponent_mask = 0x7c00;
    static const uint16_t mantissa_mask = 0x03ff;

    /// \name Constructors
    /// @{

    /** Construct from a float, double, or int using
     * round-to-nearest-ties-to-even. Out-of-range values become +/-
     * infinity.
     */
    // @{
    explicit float16_t(float value);
    explicit float16_t(double value);
    explicit float16_t(int value);
    // @}

    /** Construct a float16_t with the bits initialised to 0. This represents
     * positive zero.*/
    float16_t() = default;

    /// @}

    // Use explicit to avoid accidentally raising the precision
    /** Cast to float */
    explicit operator float() const;
    /** Cast to double */
    explicit operator double() const;
    /** Cast to int */
    explicit operator int() const;

    /** Get a new float16_t that represents a special value */
    // @{
    static float16_t make_zero();
    static float16_t make_negative_zero();
    static float16_t make_infinity();
    static float16_t make_negative_infinity();
    static float16_t make_nan();
    // @}

    /** Get a new float16_t with the given raw bits
     *
     * \param bits The bits conformant to IEEE754 binary16
     */
    static float16_t make_from_bits(uint16_t bits);

    /** Return a new float16_t with a negated sign bit*/
    float16_t operator-() const;

    /** Arithmetic operators. */
    // @{
    float16_t operator+(float16_t rhs) const;
    float16_t operator-(float16_t rhs) const;
    float16_t operator*(float16_t rhs) const;
    float16_t operator/(float16_t rhs) const;
    // Compound assignments are defined in terms of the binary operators above.
    float16_t operator+=(float16_t rhs) {
        return (*this = *this + rhs);
    }
    float16_t operator-=(float16_t rhs) {
        return (*this = *this - rhs);
    }
    float16_t operator*=(float16_t rhs) {
        return (*this = *this * rhs);
    }
    float16_t operator/=(float16_t rhs) {
        return (*this = *this / rhs);
    }
    // @}

    /** Comparison operators */
    // @{
    bool operator==(float16_t rhs) const;
    bool operator!=(float16_t rhs) const {
        return !(*this == rhs);
    }
    bool operator>(float16_t rhs) const;
    bool operator<(float16_t rhs) const;
    // >= and <= are composed from > / < and ==.
    bool operator>=(float16_t rhs) const {
        return (*this > rhs) || (*this == rhs);
    }
    bool operator<=(float16_t rhs) const {
        return (*this < rhs) || (*this == rhs);
    }
    // @}

    /** Properties */
    // @{
    bool is_nan() const;
    bool is_infinity() const;
    bool is_negative() const;
    bool is_zero() const;
    // @}

    /** Returns the bits that represent this float16_t.
     *
     * An alternative method to access the bits is to cast a pointer
     * to this instance as a pointer to a uint16_t.
     **/
    uint16_t to_bits() const;

private:
    // The raw bits. This is the only state (see the static_assert below
    // that pins sizeof(float16_t) to 2 bytes).
    uint16_t data = 0;
};
2616
2617static_assert(sizeof(float16_t) == 2, "float16_t should occupy two bytes");
2618
2619} // namespace Halide
2620
// halide_type_of<> specialization: float16_t is a 16-bit float.
template<>
HALIDE_ALWAYS_INLINE halide_type_t halide_type_of<Halide::float16_t>() {
    return halide_type_t(halide_type_float, 16);
}
2625
2626namespace Halide {
2627
/** Class that provides a type that implements half precision
 * floating point using the bfloat16 format.
 *
 * This type is enforced to be 16 bits wide and maintains no state
 * other than the raw bits, so that it can be passed to code that checks
 * a type's size and used for halide_buffer_t allocation. */
struct bfloat16_t {

    // Layout constants for the bfloat16 format:
    // 1 sign bit (bit 15), 8 exponent bits (bits 7-14), 7 mantissa bits.
    static const int mantissa_bits = 7;
    static const uint16_t sign_mask = 0x8000;
    static const uint16_t exponent_mask = 0x7f80;
    static const uint16_t mantissa_mask = 0x007f;

    // Commonly-used special values, declared here and defined out-of-line.
    static const bfloat16_t zero, negative_zero, infinity, negative_infinity, nan;

    /// \name Constructors
    /// @{

    /** Construct from a float, double, or int using
     * round-to-nearest-ties-to-even. Out-of-range values become +/-
     * infinity.
     */
    // @{
    explicit bfloat16_t(float value);
    explicit bfloat16_t(double value);
    explicit bfloat16_t(int value);
    // @}

    /** Construct a bfloat16_t with the bits initialised to 0. This represents
     * positive zero.*/
    bfloat16_t() = default;

    /// @}

    // Use explicit to avoid accidentally raising the precision
    /** Cast to float */
    explicit operator float() const;
    /** Cast to double */
    explicit operator double() const;
    /** Cast to int */
    explicit operator int() const;

    /** Get a new bfloat16_t that represents a special value */
    // @{
    static bfloat16_t make_zero();
    static bfloat16_t make_negative_zero();
    static bfloat16_t make_infinity();
    static bfloat16_t make_negative_infinity();
    static bfloat16_t make_nan();
    // @}

    /** Get a new bfloat16_t with the given raw bits
     *
     * \param bits The bits in the bfloat16 format
     */
    static bfloat16_t make_from_bits(uint16_t bits);

    /** Return a new bfloat16_t with a negated sign bit*/
    bfloat16_t operator-() const;

    /** Arithmetic operators. */
    // @{
    bfloat16_t operator+(bfloat16_t rhs) const;
    bfloat16_t operator-(bfloat16_t rhs) const;
    bfloat16_t operator*(bfloat16_t rhs) const;
    bfloat16_t operator/(bfloat16_t rhs) const;
    // Compound assignments are defined in terms of the binary operators above.
    bfloat16_t operator+=(bfloat16_t rhs) {
        return (*this = *this + rhs);
    }
    bfloat16_t operator-=(bfloat16_t rhs) {
        return (*this = *this - rhs);
    }
    bfloat16_t operator*=(bfloat16_t rhs) {
        return (*this = *this * rhs);
    }
    bfloat16_t operator/=(bfloat16_t rhs) {
        return (*this = *this / rhs);
    }
    // @}

    /** Comparison operators */
    // @{
    bool operator==(bfloat16_t rhs) const;
    bool operator!=(bfloat16_t rhs) const {
        return !(*this == rhs);
    }
    bool operator>(bfloat16_t rhs) const;
    bool operator<(bfloat16_t rhs) const;
    // >= and <= are composed from > / < and ==.
    bool operator>=(bfloat16_t rhs) const {
        return (*this > rhs) || (*this == rhs);
    }
    bool operator<=(bfloat16_t rhs) const {
        return (*this < rhs) || (*this == rhs);
    }
    // @}

    /** Properties */
    // @{
    bool is_nan() const;
    bool is_infinity() const;
    bool is_negative() const;
    bool is_zero() const;
    // @}

    /** Returns the bits that represent this bfloat16_t.
     *
     * An alternative method to access the bits is to cast a pointer
     * to this instance as a pointer to a uint16_t.
     **/
    uint16_t to_bits() const;

private:
    // The raw bits. This is the only state (see the static_assert below
    // that pins sizeof(bfloat16_t) to 2 bytes).
    uint16_t data = 0;
};
2743
2744static_assert(sizeof(bfloat16_t) == 2, "bfloat16_t should occupy two bytes");
2745
2746} // namespace Halide
2747
// halide_type_of<> specialization: bfloat16_t is a 16-bit brain-float.
template<>
HALIDE_ALWAYS_INLINE halide_type_t halide_type_of<Halide::bfloat16_t>() {
    return halide_type_t(halide_type_bfloat, 16);
}
2752
2753#endif
2754// Always use assert, even if llvm-config defines NDEBUG
2755#ifdef NDEBUG
2756#undef NDEBUG
2757#include <assert.h>
2758#define NDEBUG
2759#else
2760#include <cassert>
2761#endif
2762
2763#ifndef HALIDE_UTIL_H
2764#define HALIDE_UTIL_H
2765
/** \file
 * Various utility functions used internally by Halide. */
2768
2769#include <cstdint>
2770#include <cstring>
2771#include <functional>
2772#include <limits>
2773#include <string>
2774#include <utility>
2775#include <vector>
2776
2777
2778#ifdef Halide_STATIC_DEFINE
2779#define HALIDE_EXPORT
2780#else
2781#if defined(_MSC_VER)
2782// Halide_EXPORTS is quietly defined by CMake when building a shared library
2783#ifdef Halide_EXPORTS
2784#define HALIDE_EXPORT __declspec(dllexport)
2785#else
2786#define HALIDE_EXPORT __declspec(dllimport)
2787#endif
2788#else
2789#define HALIDE_EXPORT __attribute__((visibility("default")))
2790#endif
2791#endif
2792
2793// If we're in user code, we don't want certain functions to be inlined.
2794#if defined(COMPILING_HALIDE) || defined(BUILDING_PYTHON)
2795#define HALIDE_NO_USER_CODE_INLINE
2796#else
2797#define HALIDE_NO_USER_CODE_INLINE HALIDE_NEVER_INLINE
2798#endif
2799
2800namespace Halide {
2801
2802/** Load a plugin in the form of a dynamic library (e.g. for custom autoschedulers).
2803 * If the string doesn't contain any . characters, the proper prefix and/or suffix
2804 * for the platform will be added:
2805 *
2806 * foo -> libfoo.so (Linux/OSX/etc -- note that .dylib is not supported)
2807 * foo -> foo.dll (Windows)
2808 *
2809 * otherwise, it is assumed to be an appropriate pathname.
2810 *
2811 * Any error in loading will assert-fail. */
2812void load_plugin(const std::string &lib_name);
2813
2814namespace Internal {
2815
2816/** Some numeric conversions are UB if the value won't fit in the result;
2817 * safe_numeric_cast<>() is meant as a drop-in replacement for a C/C++ cast
2818 * that adds well-defined behavior for the UB cases, attempting to mimic
2819 * common implementation behavior as much as possible.
2820 */
template<typename DST, typename SRC,
         typename std::enable_if<std::is_floating_point<SRC>::value>::type * = nullptr>
DST safe_numeric_cast(SRC s) {
    if (std::is_integral<DST>::value) {
        // NaN compares false with everything (including itself), so without
        // this check it would fall through to the raw cast below, which is
        // UB for float -> int. Pick zero as the arbitrary-but-safe result.
        if (s != s) {
            return DST(0);
        }
        // Treat float -> int as a saturating cast; this is handled
        // in different ways by different compilers, so an arbitrary but safe
        // choice like this is reasonable.
        if (s < (SRC)std::numeric_limits<DST>::min()) {
            return std::numeric_limits<DST>::min();
        }
        if (s > (SRC)std::numeric_limits<DST>::max()) {
            return std::numeric_limits<DST>::max();
        }
    }
    return (DST)s;
}
2837
template<typename DST, typename SRC,
         typename std::enable_if<std::is_integral<SRC>::value>::type * = nullptr>
DST safe_numeric_cast(SRC s) {
    // any-int -> narrower-signed-int is technically UB if the value won't
    // fit; in practice, common compilers implement such conversions exactly
    // as the masking below does (as verified by exhaustive testing on Clang
    // for x86-64). We could probably continue to rely on that behavior, but
    // making it explicit avoids the possible wrath of UBSan and similar
    // debug helpers. (Yes, using sizeof for this comparison is a little odd
    // for the uint->int case, but the intent is to match existing common
    // behavior, which this does.)
    const bool masked_narrowing = std::is_integral<DST>::value &&
                                  std::is_integral<SRC>::value &&
                                  std::is_signed<DST>::value &&
                                  sizeof(DST) < sizeof(SRC);
    if (masked_narrowing) {
        using UnsignedSrc = typename std::make_unsigned<SRC>::type;
        return (DST)(s & (UnsignedSrc)(-1));
    }
    return (DST)s;
}
2856
2857/** An aggressive form of reinterpret cast used for correct type-punning. */
template<typename DstType, typename SrcType>
DstType reinterpret_bits(const SrcType &src) {
    static_assert(sizeof(SrcType) == sizeof(DstType), "Types must be same size");
    // memcpy is the portable way to reinterpret one object's bytes as
    // another type without violating strict-aliasing rules.
    DstType result;
    memcpy(&result, &src, sizeof(DstType));
    return result;
}
2865
2866/** Make a unique name for an object based on the name of the stack
2867 * variable passed in. If introspection isn't working or there are no
2868 * debug symbols, just uses unique_name with the given prefix. */
2869std::string make_entity_name(void *stack_ptr, const std::string &type, char prefix);
2870
/** Get the value of an environment variable. Returns its value
 * if it is defined in the environment. If the var is not defined, an empty
 * string is returned.
 */
2875std::string get_env_variable(char const *env_var_name);
2876
2877/** Get the name of the currently running executable. Platform-specific.
2878 * If program name cannot be retrieved, function returns an empty string. */
2879std::string running_program_name();
2880
2881/** Generate a unique name starting with the given prefix. It's unique
2882 * relative to all other strings returned by unique_name in this
2883 * process.
2884 *
2885 * The single-character version always appends a numeric suffix to the
2886 * character.
2887 *
2888 * The string version will either return the input as-is (with high
2889 * probability on the first time it is called with that input), or
2890 * replace any existing '$' characters with underscores, then add a
2891 * '$' sign and a numeric suffix to it.
2892 *
2893 * Note that unique_name('f') therefore differs from
2894 * unique_name("f"). The former returns something like f123, and the
2895 * latter returns either f or f$123.
2896 */
2897// @{
2898std::string unique_name(char prefix);
2899std::string unique_name(const std::string &prefix);
2900// @}
2901
2902/** Test if the first string starts with the second string */
2903bool starts_with(const std::string &str, const std::string &prefix);
2904
2905/** Test if the first string ends with the second string */
2906bool ends_with(const std::string &str, const std::string &suffix);
2907
2908/** Replace all matches of the second string in the first string with the last string */
2909std::string replace_all(const std::string &str, const std::string &find, const std::string &replace);
2910
2911/** Split the source string using 'delim' as the divider. */
2912std::vector<std::string> split_string(const std::string &source, const std::string &delim);
2913
2914/** Perform a left fold of a vector. Returns a default-constructed
2915 * vector element if the vector is empty. Similar to std::accumulate
2916 * but with a less clunky syntax. */
2917template<typename T, typename Fn>
2918T fold_left(const std::vector<T> &vec, Fn f) {
2919 T result;
2920 if (vec.empty()) {
2921 return result;
2922 }
2923 result = vec[0];
2924 for (size_t i = 1; i < vec.size(); i++) {
2925 result = f(result, vec[i]);
2926 }
2927 return result;
2928}
2929
2930/** Returns a right fold of a vector. Returns a default-constructed
2931 * vector element if the vector is empty. */
2932template<typename T, typename Fn>
2933T fold_right(const std::vector<T> &vec, Fn f) {
2934 T result;
2935 if (vec.empty()) {
2936 return result;
2937 }
2938 result = vec.back();
2939 for (size_t i = vec.size() - 1; i > 0; i--) {
2940 result = f(vec[i - 1], result);
2941 }
2942 return result;
2943}
2944
/** Variadic logical AND over the ::value member of each type.
 * The empty pack is true. */
template<typename... T>
struct meta_and : std::true_type {};

template<typename T1, typename... Args>
struct meta_and<T1, Args...> : std::integral_constant<bool, T1::value && meta_and<Args...>::value> {};

/** Variadic logical OR over the ::value member of each type.
 * The empty pack is false. */
template<typename... T>
struct meta_or : std::false_type {};

template<typename T1, typename... Args>
struct meta_or<T1, Args...> : std::integral_constant<bool, T1::value || meta_or<Args...>::value> {};

/** True iff every type in Args is convertible to To. */
template<typename To, typename... Args>
struct all_are_convertible : meta_and<std::is_convertible<Args, To>...> {};
2959
2960/** Returns base name and fills in namespaces, outermost one first in vector. */
2961std::string extract_namespaces(const std::string &name, std::vector<std::string> &namespaces);
2962
2963/** Overload that returns base name only */
2964std::string extract_namespaces(const std::string &name);
2965
// File metadata as reported by stat() (see file_stat below).
struct FileStat {
    uint64_t file_size;  // size in bytes
    uint32_t mod_time;   // Unix epoch time
    uint32_t uid;        // owning user id
    uint32_t gid;        // owning group id
    uint32_t mode;       // permission/mode bits
};
2973
2974/** Create a unique file with a name of the form prefixXXXXXsuffix in an arbitrary
2975 * (but writable) directory; this is typically /tmp, but the specific
2976 * location is not guaranteed. (Note that the exact form of the file name
2977 * may vary; in particular, the suffix may be ignored on Windows.)
2978 * The file is created (but not opened), thus this can be called from
2979 * different threads (or processes, e.g. when building with parallel make)
2980 * without risking collision. Note that if this file is used as a temporary
 * file, the caller is responsible for deleting it. Neither the prefix nor suffix
2982 * may contain a directory separator.
2983 */
2984std::string file_make_temp(const std::string &prefix, const std::string &suffix);
2985
2986/** Create a unique directory in an arbitrary (but writable) directory; this is
2987 * typically somewhere inside /tmp, but the specific location is not guaranteed.
2988 * The directory will be empty (i.e., this will never return /tmp itself,
2989 * but rather a new directory inside /tmp). The caller is responsible for removing the
2990 * directory after use.
2991 */
2992std::string dir_make_temp();
2993
2994/** Wrapper for access(). Quietly ignores errors. */
2995bool file_exists(const std::string &name);
2996
2997/** assert-fail if the file doesn't exist. useful primarily for testing purposes. */
2998void assert_file_exists(const std::string &name);
2999
3000/** assert-fail if the file DOES exist. useful primarily for testing purposes. */
3001void assert_no_file_exists(const std::string &name);
3002
3003/** Wrapper for unlink(). Asserts upon error. */
3004void file_unlink(const std::string &name);
3005
/** NOTE(review): duplicate declaration of file_unlink (see above), and this
 * comment contradicts the one there ("quietly ignores errors" vs. "asserts
 * upon error") — confirm which matches the implementation and remove one. */
3007void file_unlink(const std::string &name);
3008
3009/** Ensure that no file with this path exists. If such a file
3010 * exists and cannot be removed, assert-fail. */
3011void ensure_no_file_exists(const std::string &name);
3012
3013/** Wrapper for rmdir(). Asserts upon error. */
3014void dir_rmdir(const std::string &name);
3015
3016/** Wrapper for stat(). Asserts upon error. */
3017FileStat file_stat(const std::string &name);
3018
3019/** Read the entire contents of a file into a vector<char>. The file
3020 * is read in binary mode. Errors trigger an assertion failure. */
3021std::vector<char> read_entire_file(const std::string &pathname);
3022
3023/** Create or replace the contents of a file with a given pointer-and-length
3024 * of memory. If the file doesn't exist, it is created; if it does exist, it
3025 * is completely overwritten. Any error triggers an assertion failure. */
3026void write_entire_file(const std::string &pathname, const void *source, size_t source_len);
3027
/** Convenience overload of write_entire_file() that writes the contents
 * of a std::vector<char>. Any error triggers an assertion failure. */
inline void write_entire_file(const std::string &pathname, const std::vector<char> &source) {
    write_entire_file(pathname, source.data(), source.size());
}
3031
3032/** A simple utility class that creates a temporary file in its ctor and
3033 * deletes that file in its dtor; this is useful for temporary files that you
3034 * want to ensure are deleted when exiting a certain scope. Since this is essentially
3035 * just an RAII wrapper around file_make_temp() and file_unlink(), it has the same
3036 * failure modes (i.e.: assertion upon error).
3037 */
3038class TemporaryFile final {
3039public:
3040 TemporaryFile(const std::string &prefix, const std::string &suffix)
3041 : temp_path(file_make_temp(prefix, suffix)) {
3042 }
3043 const std::string &pathname() const {
3044 return temp_path;
3045 }
3046 ~TemporaryFile() {
3047 if (do_unlink) {
3048 file_unlink(temp_path);
3049 }
3050 }
3051 // You can call this if you want to defeat the automatic deletion;
3052 // this is rarely what you want to do (since it defeats the purpose
3053 // of this class), but can be quite handy for debugging purposes.
3054 void detach() {
3055 do_unlink = false;
3056 }
3057
3058private:
3059 const std::string temp_path;
3060 bool do_unlink = true;
3061
3062public:
3063 TemporaryFile(const TemporaryFile &) = delete;
3064 TemporaryFile &operator=(const TemporaryFile &) = delete;
3065 TemporaryFile(TemporaryFile &&) = delete;
3066 TemporaryFile &operator=(TemporaryFile &&) = delete;
3067};
3068
/** Routines to test if math would overflow for signed integers with
 * the given number of bits. Declarations only; definitions live in the
 * corresponding .cpp. */
// @{
bool add_would_overflow(int bits, int64_t a, int64_t b);
bool sub_would_overflow(int bits, int64_t a, int64_t b);
bool mul_would_overflow(int bits, int64_t a, int64_t b);
// @}
3076
/** Helper class for saving/restoring variable values on the stack, to allow
 * for early-exit that preserves correctness: the referenced variable is
 * restored to its original value when the ScopedValue is destroyed. */
template<typename T>
struct ScopedValue {
    T &var;
    T old_value;
    /** Preserve the old value, restored at dtor time */
    ScopedValue(T &var)
        : var(var), old_value(var) {
    }
    /** Preserve the old value, then set the var to a new value. */
    ScopedValue(T &var, T new_value)
        : var(var), old_value(var) {
        var = new_value;
    }
    ~ScopedValue() {
        var = old_value;
    }
    /** Reads back the saved (pre-scope) value, not the current one. */
    operator T() const {
        return old_value;
    }
    // allow move but not copy
    ScopedValue(const ScopedValue &that) = delete;
    ScopedValue(ScopedValue &&that) noexcept = default;
    // Assignment was already impossible (the reference member makes copy
    // assignment implicitly deleted, and declaring a move ctor suppresses
    // move assignment); delete both explicitly to document intent.
    ScopedValue &operator=(const ScopedValue &that) = delete;
    ScopedValue &operator=(ScopedValue &&that) = delete;
};
3102
// Wrappers for some C++14-isms that are useful and trivially implementable
// in C++11; these are defined in the Halide::Internal namespace. If we
// are compiling under C++14 or later, we just use the standard implementations
// rather than our own.
// NOTE(review): MSVC reports __cplusplus as 199711L unless /Zc:__cplusplus is
// passed, so this check may select the C++11 fallback there. Harmless, but
// worth confirming if MSVC builds are supported.
#if __cplusplus >= 201402L

// C++14: Use the standard implementations
using std::index_sequence;
using std::integer_sequence;
using std::make_index_sequence;
using std::make_integer_sequence;

#else

// C++11: std::integer_sequence (etc) is standard in C++14 but not C++11, but
// is easily written in C++11. This is a simple version that could
// probably be improved.

// Compile-time list of integers of type T; size() reports how many.
template<typename T, T... Ints>
struct integer_sequence {
    static constexpr size_t size() {
        return sizeof...(Ints);
    }
};

// Metafunction that appends the next integer to a sequence:
// next_integer_sequence<integer_sequence<T, 0, 1>>::type is
// integer_sequence<T, 0, 1, 2>.
template<typename T>
struct next_integer_sequence;

template<typename T, T... Ints>
struct next_integer_sequence<integer_sequence<T, Ints...>> {
    using type = integer_sequence<T, Ints..., sizeof...(Ints)>;
};

// Builds integer_sequence<T, 0, ..., N-1> by recursing from I up to N.
// Note the O(N) recursion depth, so keep N modest.
template<typename T, T I, T N>
struct make_integer_sequence_helper {
    using type = typename next_integer_sequence<
        typename make_integer_sequence_helper<T, I + 1, N>::type>::type;
};

// Base case: I == N yields the empty sequence.
template<typename T, T N>
struct make_integer_sequence_helper<T, N, N> {
    using type = integer_sequence<T>;
};

template<typename T, T N>
using make_integer_sequence = typename make_integer_sequence_helper<T, 0, N>::type;

template<size_t... Ints>
using index_sequence = integer_sequence<size_t, Ints...>;

template<size_t N>
using make_index_sequence = make_integer_sequence<size_t, N>;

#endif
3157
// Helpers for timing blocks of code. Put 'TIC;' at the start and
// 'TOC;' at the end. Timing is reported at the toc via
// debug(0). The calls can be nested and will pretty-print
// appropriately. Took this idea from matlab via Jon Barron.
//
// Note that this uses global state internally, and is not thread-safe
// at all. Only use it for single-threaded debugging sessions.

void halide_tic_impl(const char *file, int line);
void halide_toc_impl(const char *file, int line);
#define HALIDE_TIC Halide::Internal::halide_tic_impl(__FILE__, __LINE__)
#define HALIDE_TOC Halide::Internal::halide_toc_impl(__FILE__, __LINE__)
// The short TIC/TOC spellings are only defined when building Halide itself,
// so they don't pollute the macro namespace of code including this header.
#ifdef COMPILING_HALIDE
#define TIC HALIDE_TIC
#define TOC HALIDE_TOC
#endif
3174
// statically cast a value from one type to another: this is really just
// some syntactic sugar around static_cast<>() to avoid compiler warnings
// regarding 'bool' in some compilation configurations.
template<typename TO>
struct StaticCast {
    // General case: TO is not bool; defer to static_cast directly.
    // (TO2 exists so the enable_if condition is dependent and SFINAE applies.)
    template<typename FROM, typename TO2 = TO, typename std::enable_if<!std::is_same<TO2, bool>::value>::type * = nullptr>
    inline constexpr static TO2 value(const FROM &from) {
        return static_cast<TO2>(from);
    }

    // bool case: compare against zero instead of a narrowing cast, which is
    // what triggers the warnings mentioned above.
    template<typename FROM, typename TO2 = TO, typename std::enable_if<std::is_same<TO2, bool>::value>::type * = nullptr>
    inline constexpr static TO2 value(const FROM &from) {
        return from != 0;
    }
};
3190
// Like std::is_convertible, but with additional tests for arithmetic types:
// ensure that the value will roundtrip losslessly (e.g., no integer truncation
// or dropping of fractional parts).
template<typename TO>
struct IsRoundtrippable {
    // FROM is not convertible to TO at all: never roundtrippable.
    template<typename FROM, typename TO2 = TO, typename std::enable_if<!std::is_convertible<FROM, TO>::value>::type * = nullptr>
    inline constexpr static bool value(const FROM &from) {
        return false;
    }

    // Both arithmetic, of differing types: convert FROM -> TO -> FROM and
    // check that the original value survives the round trip.
    template<typename FROM, typename TO2 = TO, typename std::enable_if<std::is_convertible<FROM, TO>::value && std::is_arithmetic<TO>::value && std::is_arithmetic<FROM>::value && !std::is_same<TO, FROM>::value>::type * = nullptr>
    inline constexpr static bool value(const FROM &from) {
        return StaticCast<FROM>::value(StaticCast<TO>::value(from)) == from;
    }

    // Convertible, but not an arithmetic-conversion situation (same type, or
    // non-arithmetic types): trivially roundtrippable.
    template<typename FROM, typename TO2 = TO, typename std::enable_if<std::is_convertible<FROM, TO>::value && !(std::is_arithmetic<TO>::value && std::is_arithmetic<FROM>::value && !std::is_same<TO, FROM>::value)>::type * = nullptr>
    inline constexpr static bool value(const FROM &from) {
        return true;
    }
};
3211
/** Emit a version of a string that is a valid identifier in C (. is replaced with _).
 * NOTE(review): other characters may be transformed as well -- see the definition. */
std::string c_print_name(const std::string &name);

/** Return the LLVM_VERSION against which this libHalide is compiled. This is provided
 * only for internal tests which need to verify behavior; please don't use this outside
 * of Halide tests. */
int get_llvm_version();

/** Call the given action in a platform-specific context that provides at least
 * 8MB of stack space. Currently only has any effect on Windows where it uses
 * a Fiber. */
void run_with_large_stack(const std::function<void()> &action);
3224
3225} // namespace Internal
3226} // namespace Halide
3227
3228#endif
3229#include <cstdint>
3230
3231/** \file
3232 * Defines halide types
3233 */
3234
3235/** A set of types to represent a C++ function signature. This allows
3236 * two things. First, proper prototypes can be provided for Halide
3237 * generated functions, giving better compile time type
3238 * checking. Second, C++ name mangling can be done to provide link
3239 * time type checking for both Halide generated functions and calls
3240 * from Halide to external functions.
3241 *
 * These are intended to be constexpr producible, but we don't depend
 * on C++11 yet. In C++14, it is possible these will be replaced with
 * introspection/reflection facilities.
3245 *
3246 * halide_handle_traits has to go outside the Halide namespace due to template
3247 * resolution rules. TODO(zalman): Do all types need to be in global namespace?
3248 */
3249//@{
3250
/** A structure to represent the (unscoped) name of a C++ composite type for use
 * as a single argument (or return value) in a function signature.
 *
 * Currently does not support the restrict qualifier, references, or
 * r-value references. These features cannot be used in extern
 * function calls from Halide or in the generated function from
 * Halide, but their applicability seems limited anyway.
 *
 * Although this is in the global namespace, it should be considered "Halide Internal"
 * and subject to change; code outside Halide should avoid referencing it.
 */
struct halide_cplusplus_type_name {
    /// An enum to indicate whether a C++ type is non-composite, a struct, class, or union
    enum CPPTypeType {
        Simple, ///< "int"
        Struct, ///< "struct Foo"
        Class,  ///< "class Foo"
        Union,  ///< "union Foo"
        Enum,   ///< "enum Foo"
    } cpp_type_type;  // Note: order is reflected in map_to_name table in CPlusPlusMangle.cpp

    std::string name;

    halide_cplusplus_type_name(CPPTypeType cpp_type_type, const std::string &name)
        : cpp_type_type(cpp_type_type), name(name) {
    }

    /** Equal iff both the composite kind and the name match. */
    bool operator==(const halide_cplusplus_type_name &rhs) const {
        return cpp_type_type == rhs.cpp_type_type && name == rhs.name;
    }

    bool operator!=(const halide_cplusplus_type_name &rhs) const {
        return !(*this == rhs);
    }

    /** Lexicographic order: first by composite kind, then by name. */
    bool operator<(const halide_cplusplus_type_name &rhs) const {
        if (cpp_type_type != rhs.cpp_type_type) {
            return cpp_type_type < rhs.cpp_type_type;
        }
        return name < rhs.name;
    }
};
3293
/** A structure to represent the fully scoped name of a C++ composite
 * type for use in generating function signatures that use that type.
 *
 * This is intended to be a constexpr usable type, but we don't depend
 * on C++11 yet. In C++14, it is possible this will be replaced with
 * introspection/reflection facilities.
 *
 * Although this is in the global namespace, it should be considered "Halide Internal"
 * and subject to change; code outside Halide should avoid referencing it.
 */
struct halide_handle_cplusplus_type {
    halide_cplusplus_type_name inner_name;
    std::vector<std::string> namespaces;
    std::vector<halide_cplusplus_type_name> enclosing_types;

    /// One set of modifiers on a type.
    /// The const/volatile/restrict properties are "inside" the pointer property.
    enum Modifier : uint8_t {
        Const = 1 << 0,    ///< Bitmask flag for "const"
        Volatile = 1 << 1, ///< Bitmask flag for "volatile"
        Restrict = 1 << 2, ///< Bitmask flag for "restrict"
        Pointer = 1 << 3,  ///< Bitmask flag for a pointer "*"
    };

    /// Qualifiers and indirections on type. 0 is innermost.
    std::vector<uint8_t> cpp_type_modifiers;

    /// References are separate because they only occur at the outermost level.
    /// No modifiers are needed for references as they are not allowed to apply
    /// to the reference itself. (This isn't true for restrict, but that is a C++
    /// extension anyway.) If modifiers are needed, the last entry in the above
    /// array would be the modifiers for the reference.
    enum ReferenceType : uint8_t {
        NotReference = 0,
        LValueReference = 1, // "&"
        RValueReference = 2, // "&&"
    };
    ReferenceType reference_type;

    halide_handle_cplusplus_type(const halide_cplusplus_type_name &inner_name,
                                 const std::vector<std::string> &namespaces = {},
                                 const std::vector<halide_cplusplus_type_name> &enclosing_types = {},
                                 const std::vector<uint8_t> &modifiers = {},
                                 ReferenceType reference_type = NotReference)
        : inner_name(inner_name),
          namespaces(namespaces),
          enclosing_types(enclosing_types),
          cpp_type_modifiers(modifiers),
          reference_type(reference_type) {
    }

    /** Build the description of an arbitrary C++ type T; defined below once
     * the supporting traits are available. */
    template<typename T>
    static halide_handle_cplusplus_type make();
};
//@}
3349
3350/** halide_c_type_to_name is a utility class used to provide a user-extensible
3351 * way of naming Handle types.
3352 *
3353 * Although this is in the global namespace, it should be considered "Halide Internal"
3354 * and subject to change; code outside Halide should avoid referencing it
3355 * directly (use the HALIDE_DECLARE_EXTERN_xxx macros instead).
3356 */
3357template<typename T>
3358struct halide_c_type_to_name {
3359 static constexpr bool known_type = false;
3360 static halide_cplusplus_type_name name() {
3361 return {halide_cplusplus_type_name::Simple, "void"};
3362 }
3363};
3364
// Generates a halide_c_type_to_name specialization that marks Type as
// "known" and records both its composite kind and its stringified name.
#define HALIDE_DECLARE_EXTERN_TYPE(TypeType, Type) \
    template<> \
    struct halide_c_type_to_name<Type> { \
        static constexpr bool known_type = true; \
        static halide_cplusplus_type_name name() { \
            return {halide_cplusplus_type_name::TypeType, #Type}; \
        } \
    }

#define HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(T) HALIDE_DECLARE_EXTERN_TYPE(Simple, T)
#define HALIDE_DECLARE_EXTERN_STRUCT_TYPE(T) HALIDE_DECLARE_EXTERN_TYPE(Struct, T)
#define HALIDE_DECLARE_EXTERN_CLASS_TYPE(T) HALIDE_DECLARE_EXTERN_TYPE(Class, T)
#define HALIDE_DECLARE_EXTERN_UNION_TYPE(T) HALIDE_DECLARE_EXTERN_TYPE(Union, T)

// Registrations for the fundamental scalar types and the Halide runtime structs.
HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(char);
HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(bool);
HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(int8_t);
HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(uint8_t);
HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(int16_t);
HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(uint16_t);
HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(int32_t);
HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(uint32_t);
HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(int64_t);
HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(uint64_t);
HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(Halide::float16_t);
HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(Halide::bfloat16_t);
HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(float);
HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(double);
HALIDE_DECLARE_EXTERN_STRUCT_TYPE(halide_buffer_t);
HALIDE_DECLARE_EXTERN_STRUCT_TYPE(halide_dimension_t);
HALIDE_DECLARE_EXTERN_STRUCT_TYPE(halide_device_interface_t);
HALIDE_DECLARE_EXTERN_STRUCT_TYPE(halide_filter_metadata_t);
HALIDE_DECLARE_EXTERN_STRUCT_TYPE(halide_semaphore_t);
HALIDE_DECLARE_EXTERN_STRUCT_TYPE(halide_parallel_task_t);
3399
3400// You can make arbitrary user-defined types be "Known" using the
3401// macro above. This is useful for making Param<> arguments for
3402// Generators type safe. e.g.,
3403//
3404// struct MyFunStruct { ... };
3405//
3406// ...
3407//
3408// HALIDE_DECLARE_EXTERN_STRUCT_TYPE(MyFunStruct);
3409//
3410// ...
3411//
3412// class MyGenerator : public Generator<MyGenerator> {
3413// Param<const MyFunStruct *> my_struct_ptr;
3414// ...
3415// };
3416
// Generate the halide_handle_cplusplus_type describing T. T must be a
// pointer or reference type unless it was registered as "known" via the
// HALIDE_DECLARE_EXTERN_xxx macros (see static_assert below).
template<typename T>
/*static*/ halide_handle_cplusplus_type halide_handle_cplusplus_type::make() {
    constexpr bool is_ptr = std::is_pointer<T>::value;
    constexpr bool is_lvalue_reference = std::is_lvalue_reference<T>::value;
    constexpr bool is_rvalue_reference = std::is_rvalue_reference<T>::value;

    // Strip a reference first, then a pointer, to reach the underlying type.
    using TBase = typename std::remove_pointer<typename std::remove_reference<T>::type>::type;
    constexpr bool is_const = std::is_const<TBase>::value;
    constexpr bool is_volatile = std::is_volatile<TBase>::value;

    // Pack the pointer/const/volatile properties into a single Modifier bitmask.
    constexpr uint8_t modifiers = static_cast<uint8_t>(
        (is_ptr ? halide_handle_cplusplus_type::Pointer : 0) |
        (is_const ? halide_handle_cplusplus_type::Const : 0) |
        (is_volatile ? halide_handle_cplusplus_type::Volatile : 0));

    // clang-format off
    constexpr halide_handle_cplusplus_type::ReferenceType ref_type =
        (is_lvalue_reference ? halide_handle_cplusplus_type::LValueReference :
         is_rvalue_reference ? halide_handle_cplusplus_type::RValueReference :
         halide_handle_cplusplus_type::NotReference);
    // clang-format on

    // Unregistered types are only usable as opaque pointers.
    using TNonCVBase = typename std::remove_cv<TBase>::type;
    constexpr bool known_type = halide_c_type_to_name<TNonCVBase>::known_type;
    static_assert(!(!known_type && !is_ptr), "Unknown types must be pointers");

    halide_handle_cplusplus_type info = {
        halide_c_type_to_name<TNonCVBase>::name(),
        {},
        {},
        {modifiers},
        ref_type};
    // Pull off any namespaces
    info.inner_name.name = Halide::Internal::extract_namespaces(info.inner_name.name, info.namespaces);
    return info;
}
3453
/** A type traits template to provide a halide_handle_cplusplus_type
 * value from a C++ type.
 *
 * Note the type represented is implicitly a pointer.
 *
 * A NULL pointer of type halide_handle_traits represents "void *".
 * This is chosen for compactness or representation as Type is a very
 * widely used data structure.
 *
 * Although this is in the global namespace, it should be considered "Halide Internal"
 * and subject to change; code outside Halide should avoid referencing it directly.
 */
template<typename T>
struct halide_handle_traits {
    // This trait must return a pointer to a global structure. I.e. it should never be freed.
    // A return value of nullptr here means "void *".
    HALIDE_ALWAYS_INLINE static const halide_handle_cplusplus_type *type_info() {
        if (std::is_pointer<T>::value ||
            std::is_lvalue_reference<T>::value ||
            std::is_rvalue_reference<T>::value) {
            // Function-local static: constructed once and lives for the
            // program, so handing out its address is safe.
            static const halide_handle_cplusplus_type the_info = halide_handle_cplusplus_type::make<T>();
            return &the_info;
        }
        return nullptr;
    }
};
3480
3481namespace Halide {
3482
3483struct Expr;
3484
/** Types in the halide type system. They can be ints, unsigned ints,
 * or floats of various bit-widths (the 'bits' field). They can also
 * be vectors of the same (by setting the 'lanes' field to something
 * larger than one). Front-end code shouldn't use vector
 * types. Instead vectorize a function. */
struct Type {
private:
    halide_type_t type;

public:
    /** Aliases for halide_type_code_t values for legacy compatibility
     * and to match the Halide internal C++ style. */
    // @{
    static const halide_type_code_t Int = halide_type_int;
    static const halide_type_code_t UInt = halide_type_uint;
    static const halide_type_code_t Float = halide_type_float;
    static const halide_type_code_t BFloat = halide_type_bfloat;
    static const halide_type_code_t Handle = halide_type_handle;
    // @}

    /** The number of bytes required to store a single scalar value of this type. Ignores vector lanes. */
    int bytes() const {
        return (bits() + 7) / 8;
    }

    // Default ctor initializes everything to predictable-but-unlikely values
    // (a zero-bit, zero-lane Handle).
    Type()
        : type(Handle, 0, 0) {
    }

    /** Construct a runtime representation of a Halide type from:
     * code: The fundamental type from an enum.
     * bits: The bit size of one element.
     * lanes: The number of vector elements in the type. */
    Type(halide_type_code_t code, int bits, int lanes, const halide_handle_cplusplus_type *handle_type = nullptr)
        : type(code, (uint8_t)bits, (uint16_t)lanes), handle_type(handle_type) {
    }

    /** Trivial copy constructor. */
    Type(const Type &that) = default;

    /** Trivial copy assignment operator. */
    Type &operator=(const Type &that) = default;

    /** Type is a wrapper around halide_type_t with more methods for use
     * inside the compiler. This simply constructs the wrapper around
     * the runtime value. */
    HALIDE_ALWAYS_INLINE
    Type(const halide_type_t &that, const halide_handle_cplusplus_type *handle_type = nullptr)
        : type(that), handle_type(handle_type) {
    }

    /** Unwrap the runtime halide_type_t for use in runtime calls, etc.
     * Representation is exactly equivalent. */
    HALIDE_ALWAYS_INLINE
    operator halide_type_t() const {
        return type;
    }

    /** Return the underlying data type of an element as an enum value. */
    HALIDE_ALWAYS_INLINE
    halide_type_code_t code() const {
        return (halide_type_code_t)type.code;
    }

    /** Return the bit size of a single element of this type. */
    HALIDE_ALWAYS_INLINE
    int bits() const {
        return type.bits;
    }

    /** Return the number of vector elements in this type. */
    HALIDE_ALWAYS_INLINE
    int lanes() const {
        return type.lanes;
    }

    /** Return Type with same number of bits and lanes, but new_code for a type code.
     * The handle_type is preserved only if the code is unchanged. */
    Type with_code(halide_type_code_t new_code) const {
        return Type(new_code, bits(), lanes(),
                    (new_code == code()) ? handle_type : nullptr);
    }

    /** Return Type with same type code and lanes, but new_bits for the number of bits.
     * The handle_type is preserved only if the bit width is unchanged. */
    Type with_bits(int new_bits) const {
        return Type(code(), new_bits, lanes(),
                    (new_bits == bits()) ? handle_type : nullptr);
    }

    /** Return Type with same type code and number of bits,
     * but new_lanes for the number of vector lanes. */
    Type with_lanes(int new_lanes) const {
        return Type(code(), bits(), new_lanes, handle_type);
    }

    /** Return Type with the same type code and number of lanes, but with twice as many bits. */
    Type widen() const {
        return with_bits(bits() * 2);
    }

    /** Return Type with the same type code and number of lanes, but with half as many bits. */
    Type narrow() const {
        return with_bits(bits() / 2);
    }

    /** Type to be printed when declaring handles of this type. */
    const halide_handle_cplusplus_type *handle_type = nullptr;

    /** Is this type boolean (represented as UInt(1))? */
    HALIDE_ALWAYS_INLINE
    bool is_bool() const {
        return code() == UInt && bits() == 1;
    }

    /** Is this type a vector type? (lanes() != 1).
     * TODO(abadams): Decide what to do for lanes() == 0. */
    HALIDE_ALWAYS_INLINE
    bool is_vector() const {
        return lanes() != 1;
    }

    /** Is this type a scalar type? (lanes() == 1).
     * TODO(abadams): Decide what to do for lanes() == 0. */
    HALIDE_ALWAYS_INLINE
    bool is_scalar() const {
        return lanes() == 1;
    }

    /** Is this type a floating point type? Note that this is true for
     * bfloat types (code() == BFloat) as well as regular floats. */
    HALIDE_ALWAYS_INLINE
    bool is_float() const {
        return code() == Float || code() == BFloat;
    }

    /** Is this type a floating point type in the bfloat format? */
    HALIDE_ALWAYS_INLINE
    bool is_bfloat() const {
        return code() == BFloat;
    }

    /** Is this type a signed integer type? */
    HALIDE_ALWAYS_INLINE
    bool is_int() const {
        return code() == Int;
    }

    /** Is this type an unsigned integer type? */
    HALIDE_ALWAYS_INLINE
    bool is_uint() const {
        return code() == UInt;
    }

    /** Is this type an integer type of any sort? */
    HALIDE_ALWAYS_INLINE
    bool is_int_or_uint() const {
        return code() == Int || code() == UInt;
    }

    /** Is this type an opaque handle type (void *) */
    HALIDE_ALWAYS_INLINE
    bool is_handle() const {
        return code() == Handle;
    }

    // Returns true iff type is a signed integral type where overflow is defined.
    HALIDE_ALWAYS_INLINE
    bool can_overflow_int() const {
        return is_int() && bits() <= 16;
    }

    // Returns true iff type does have a well-defined overflow behavior.
    HALIDE_ALWAYS_INLINE
    bool can_overflow() const {
        return is_uint() || can_overflow_int();
    }

    /** Check that the type name of two handles matches. */
    bool same_handle_type(const Type &other) const;

    /** Compare two types for equality */
    bool operator==(const Type &other) const {
        return type == other.type && (code() != Handle || same_handle_type(other));
    }

    /** Compare two types for inequality */
    bool operator!=(const Type &other) const {
        return type != other.type || (code() == Handle && !same_handle_type(other));
    }

    /** Compare ordering of two types so they can be used in certain containers and algorithms.
     * NOTE(review): for Handle types the tie-break compares the raw
     * handle_type pointers, which is only a stable order within one process.
     * Also, when this->type > other.type and this is a Handle, the pointer
     * comparison is still consulted rather than returning false outright; in
     * practice non-Handle types carry a null handle_type, but this ordering
     * deserves a closer look. */
    bool operator<(const Type &other) const {
        if (type < other.type) {
            return true;
        }
        if (code() == Handle) {
            return handle_type < other.handle_type;
        }
        return false;
    }

    /** Produce the scalar type (that of a single element) of this vector type */
    Type element_of() const {
        return with_lanes(1);
    }

    /** Can this type represent all values of another type? */
    bool can_represent(Type other) const;

    /** Can this type represent a particular constant? */
    // @{
    bool can_represent(double x) const;
    bool can_represent(int64_t x) const;
    bool can_represent(uint64_t x) const;
    // @}

    /** Check if an integer constant value is the maximum or minimum
     * representable value for this type. */
    // @{
    bool is_max(uint64_t) const;
    bool is_max(int64_t) const;
    bool is_min(uint64_t) const;
    bool is_min(int64_t) const;
    // @}

    /** Return an expression which is the maximum value of this type.
     * Returns infinity for types which can represent it. */
    Expr max() const;

    /** Return an expression which is the minimum value of this type.
     * Returns -infinity for types which can represent it. */
    Expr min() const;
};
3717
3718/** Constructing a signed integer type */
3719inline Type Int(int bits, int lanes = 1) {
3720 return Type(Type::Int, bits, lanes);
3721}
3722
3723/** Constructing an unsigned integer type */
3724inline Type UInt(int bits, int lanes = 1) {
3725 return Type(Type::UInt, bits, lanes);
3726}
3727
3728/** Construct a floating-point type */
3729inline Type Float(int bits, int lanes = 1) {
3730 return Type(Type::Float, bits, lanes);
3731}
3732
3733/** Construct a floating-point type in the bfloat format. Only 16-bit currently supported. */
3734inline Type BFloat(int bits, int lanes = 1) {
3735 return Type(Type::BFloat, bits, lanes);
3736}
3737
3738/** Construct a boolean type */
3739inline Type Bool(int lanes = 1) {
3740 return UInt(1, lanes);
3741}
3742
3743/** Construct a handle type */
3744inline Type Handle(int lanes = 1, const halide_handle_cplusplus_type *handle_type = nullptr) {
3745 return Type(Type::Handle, 64, lanes, handle_type);
3746}
3747
/** Construct the halide equivalent of a C type */
template<typename T>
inline Type type_of() {
    return Type(halide_type_of<T>(), halide_handle_traits<T>::type_info());
}

/** Halide type to a C++ type, rendered as a string.
 * NOTE(review): see the definition for the exact meaning of include_space. */
std::string type_to_c_type(Type type, bool include_space, bool c_plus_plus = true);
3756
3757} // namespace Halide
3758
3759#endif
3760
3761namespace Halide {
3762
3763struct bfloat16_t;
3764struct float16_t;
3765
3766namespace Internal {
3767
3768class IRMutator;
3769class IRVisitor;
3770
/** All our IR node types get unique IDs for the purposes of RTTI */
enum class IRNodeType {
    // Exprs, in order of strength. Code in IRMatch.h and the
    // simplifier relies on this order for canonicalization of
    // expressions, so you may need to update those modules if you
    // change this list.
    IntImm,
    UIntImm,
    FloatImm,
    StringImm,
    Broadcast,
    Cast,
    Variable,
    Add,
    Sub,
    Mod,
    Mul,
    Div,
    Min,
    Max,
    EQ,
    NE,
    LT,
    LE,
    GT,
    GE,
    And,
    Or,
    Not,
    Select,
    Load,
    Ramp,
    Call,
    Let,
    Shuffle,
    VectorReduce,
    // Stmts
    LetStmt,
    AssertStmt,
    ProducerConsumer,
    For,
    Acquire,
    Store,
    Provide,
    Allocate,
    Free,
    Realize,
    Block,
    Fork,
    IfThenElse,
    Evaluate,
    Prefetch,
    Atomic
};

// The strongest Expr node type per the canonicalization order above
// (VectorReduce is the last entry before the Stmt nodes begin).
constexpr IRNodeType StrongestExprNodeType = IRNodeType::VectorReduce;
3827
/** The abstract base classes for a node in the Halide IR. */
struct IRNode {

    /** We use the visitor pattern to traverse IR nodes throughout the
     * compiler, so we have a virtual accept method which accepts
     * visitors. Implemented by the ExprNode/StmtNode CRTP helpers below.
     */
    virtual void accept(IRVisitor *v) const = 0;
    IRNode(IRNodeType t)
        : node_type(t) {
    }
    virtual ~IRNode() = default;

    /** These classes are all managed with intrusive reference
     * counting, so we also track a reference count. It's mutable
     * so that we can do reference counting even through const
     * references to IR nodes.
     */
    mutable RefCount ref_count;

    /** Each IR node subclass has a unique identifier. We can compare
     * these values to do runtime type identification. We don't
     * compile with rtti because that injects run-time type
     * identification stuff everywhere (and often breaks when linking
     * external libraries compiled without it), and we only want it
     * for IR nodes. One might want to put this value in the vtable,
     * but that adds another level of indirection, and for Exprs we
     * have 32 free bits in between the ref count and the Type
     * anyway, so this doesn't increase the memory footprint of an IR node.
     */
    IRNodeType node_type;
};
3860
// Intrusive-pointer hook: expose the reference count embedded in the node
// so IntrusivePtr<const IRNode> (see IRHandle below) can manage it.
template<>
inline RefCount &ref_count<IRNode>(const IRNode *t) noexcept {
    return t->ref_count;
}
3865
// Intrusive-pointer hook: called to destroy the node once the last
// reference is dropped.
template<>
inline void destroy<IRNode>(const IRNode *t) {
    delete t;
}
3870
/** IR nodes are split into expressions and statements. These are
    similar to expressions and statements in C - expressions
    represent some value and have some type (e.g. x + 3), and
    statements are side-effecting pieces of code that do not
    represent a value (e.g. assert(x > 3)) */

/** A base class for statement nodes. They have no properties or
    methods beyond base IR nodes for now. */
struct BaseStmtNode : public IRNode {
    BaseStmtNode(IRNodeType t)
        : IRNode(t) {
    }
    // Dispatch hook for IRMutator; overridden by StmtNode<T> below.
    virtual Stmt mutate_stmt(IRMutator *v) const = 0;
};
3885
/** A base class for expression nodes. They all contain their types
 * (e.g. Int(32), Float(32)) */
struct BaseExprNode : public IRNode {
    BaseExprNode(IRNodeType t)
        : IRNode(t) {
    }
    // Dispatch hook for IRMutator; overridden by ExprNode<T> below.
    virtual Expr mutate_expr(IRMutator *v) const = 0;
    Type type;
};
3895
/** We use the "curiously recurring template pattern" to avoid
    duplicated code in the IR Nodes. These classes live between the
    abstract base classes and the actual IR Nodes in the
    inheritance hierarchy. It provides an implementation of the
    accept function necessary for the visitor pattern to work, and
    a concrete instantiation of a unique IRNodeType per class. */
template<typename T>
struct ExprNode : public BaseExprNode {
    // accept/mutate_expr are declared here but defined elsewhere (where
    // IRVisitor and IRMutator are complete types).
    void accept(IRVisitor *v) const override;
    Expr mutate_expr(IRMutator *v) const override;
    ExprNode()
        : BaseExprNode(T::_node_type) {
    }
    ~ExprNode() override = default;
};
3911
// Statement-side counterpart of ExprNode above: supplies accept() and
// mutate_stmt() plus the per-class IRNodeType tag.
template<typename T>
struct StmtNode : public BaseStmtNode {
    void accept(IRVisitor *v) const override;
    Stmt mutate_stmt(IRMutator *v) const override;
    StmtNode()
        : BaseStmtNode(T::_node_type) {
    }
    ~StmtNode() override = default;
};
3921
3922/** IR nodes are passed around opaque handles to them. This is a
3923 base class for those handles. It manages the reference count,
3924 and dispatches visitors. */
3925struct IRHandle : public IntrusivePtr<const IRNode> {
3926 HALIDE_ALWAYS_INLINE
3927 IRHandle() = default;
3928
3929 HALIDE_ALWAYS_INLINE
3930 IRHandle(const IRNode *p)
3931 : IntrusivePtr<const IRNode>(p) {
3932 }
3933
3934 /** Dispatch to the correct visitor method for this node. E.g. if
3935 * this node is actually an Add node, then this will call
3936 * IRVisitor::visit(const Add *) */
3937 void accept(IRVisitor *v) const {
3938 ptr->accept(v);
3939 }
3940
3941 /** Downcast this ir node to its actual type (e.g. Add, or
3942 * Select). This returns nullptr if the node is not of the requested
3943 * type. Example usage:
3944 *
3945 * if (const Add *add = node->as<Add>()) {
3946 * // This is an add node
3947 * }
3948 */
3949 template<typename T>
3950 const T *as() const {
3951 if (ptr && ptr->node_type == T::_node_type) {
3952 return (const T *)ptr;
3953 }
3954 return nullptr;
3955 }
3956
3957 IRNodeType node_type() const {
3958 return ptr->node_type;
3959 }
3960};
3961
3962/** Integer constants */
struct IntImm : public ExprNode<IntImm> {
    // The constant's value, stored widened to 64 bits.
    int64_t value;

    /** Make an IntImm node of type t holding the given value. */
    static const IntImm *make(Type t, int64_t value);

    static const IRNodeType _node_type = IRNodeType::IntImm;
};
3970
3971/** Unsigned integer constants */
struct UIntImm : public ExprNode<UIntImm> {
    // The constant's value, stored widened to 64 bits.
    uint64_t value;

    /** Make a UIntImm node of type t holding the given value. */
    static const UIntImm *make(Type t, uint64_t value);

    static const IRNodeType _node_type = IRNodeType::UIntImm;
};
3979
3980/** Floating point constants */
struct FloatImm : public ExprNode<FloatImm> {
    // The constant's value, stored widened to double precision.
    double value;

    /** Make a FloatImm node of type t holding the given value. */
    static const FloatImm *make(Type t, double value);

    static const IRNodeType _node_type = IRNodeType::FloatImm;
};
3988
3989/** String constants */
struct StringImm : public ExprNode<StringImm> {
    std::string value;

    /** Make a StringImm node holding the given string. */
    static const StringImm *make(const std::string &val);

    static const IRNodeType _node_type = IRNodeType::StringImm;
};
3997
3998} // namespace Internal
3999
4000/** A fragment of Halide syntax. It's implemented as reference-counted
4001 * handle to a concrete expression node, but it's immutable, so you
4002 * can treat it as a value type. */
struct Expr : public Internal::IRHandle {
    /** Make an undefined expression */
    HALIDE_ALWAYS_INLINE
    Expr() = default;

    /** Make an expression from a concrete expression node pointer (e.g. Add) */
    HALIDE_ALWAYS_INLINE
    Expr(const Internal::BaseExprNode *n)
        : IRHandle(n) {
    }

    /** Make an expression representing numeric constants of various types.
     * Note that the int32_t, float, float16_t, and bfloat16_t constructors
     * are implicit, so literals of those types convert directly to Exprs;
     * the other widths require an explicit construction. */
    // @{
    explicit Expr(int8_t x)
        : IRHandle(Internal::IntImm::make(Int(8), x)) {
    }
    explicit Expr(int16_t x)
        : IRHandle(Internal::IntImm::make(Int(16), x)) {
    }
    Expr(int32_t x)
        : IRHandle(Internal::IntImm::make(Int(32), x)) {
    }
    explicit Expr(int64_t x)
        : IRHandle(Internal::IntImm::make(Int(64), x)) {
    }
    explicit Expr(uint8_t x)
        : IRHandle(Internal::UIntImm::make(UInt(8), x)) {
    }
    explicit Expr(uint16_t x)
        : IRHandle(Internal::UIntImm::make(UInt(16), x)) {
    }
    explicit Expr(uint32_t x)
        : IRHandle(Internal::UIntImm::make(UInt(32), x)) {
    }
    explicit Expr(uint64_t x)
        : IRHandle(Internal::UIntImm::make(UInt(64), x)) {
    }
    Expr(float16_t x)
        : IRHandle(Internal::FloatImm::make(Float(16), (double)x)) {
    }
    Expr(bfloat16_t x)
        : IRHandle(Internal::FloatImm::make(BFloat(16), (double)x)) {
    }
    Expr(float x)
        : IRHandle(Internal::FloatImm::make(Float(32), x)) {
    }
    explicit Expr(double x)
        : IRHandle(Internal::FloatImm::make(Float(64), x)) {
    }
    // @}

    /** Make an expression representing a const string (i.e. a StringImm) */
    Expr(const std::string &s)
        : IRHandle(Internal::StringImm::make(s)) {
    }

    /** Override get() to return a BaseExprNode * instead of an IRNode * */
    HALIDE_ALWAYS_INLINE
    const Internal::BaseExprNode *get() const {
        return (const Internal::BaseExprNode *)ptr;
    }

    /** Get the type of this expression node. The expression must be
     * defined (non-null). */
    HALIDE_ALWAYS_INLINE
    Type type() const {
        return get()->type;
    }
};
4071
4072/** This lets you use an Expr as a key in a map of the form
4073 * map<Expr, Foo, ExprCompare> */
4074struct ExprCompare {
4075 bool operator()(const Expr &a, const Expr &b) const {
4076 return a.get() < b.get();
4077 }
4078};
4079
4080/** A single-dimensional span. Includes all numbers between min and
4081 * (min + extent - 1). */
struct Range {
    // Inclusive lower bound, and the number of values in the span.
    Expr min, extent;

    Range() = default;
    // Defined out-of-line.
    Range(const Expr &min_in, const Expr &extent_in);
};
4088
4089/** A multi-dimensional box. The outer product of the elements */
4090typedef std::vector<Range> Region;
4091
4092/** An enum describing different address spaces to be used with Func::store_in. */
enum class MemoryType {
    /** Let Halide select a storage type automatically */
    Auto,

    /** Heap/global memory. Allocated using halide_malloc, or
     * halide_device_malloc */
    Heap,

    /** Stack memory. Allocated using alloca. Requires a constant
     * size. Corresponds to per-thread local memory on the GPU. If all
     * accesses are at constant coordinates, may be promoted into the
     * register file at the discretion of the register allocator. */
    Stack,

    /** Register memory. The allocation should be promoted into the
     * register file. All stores must be at constant coordinates. May
     * be spilled to the stack at the discretion of the register
     * allocator. */
    Register,

    /** Allocation is stored in GPU shared memory. Also known as
     * "local" in OpenCL, and "threadgroup" in Metal. Can be shared
     * across GPU threads within the same block. */
    GPUShared,

    /** Allocation is stored in GPU texture memory and accessed through
     * hardware sampler */
    GPUTexture,

    /** Allocate Locked Cache Memory to act as local memory */
    LockedCache,
    /** Vector Tightly Coupled Memory. HVX (Hexagon) local memory available on
     * v65+. This memory has higher performance and lower power. Ideal for
     * intermediate buffers. Necessary for vgather-vscatter instructions
     * on Hexagon */
    VTCM,
};
4130
4131namespace Internal {
4132
4133/** An enum describing a type of loop traversal. Used in schedules,
4134 * and in the For loop IR node. Serial is a conventional ordered for
4135 * loop. Iterations occur in increasing order, and each iteration must
4136 * appear to have finished before the next begins. Parallel, GPUBlock,
4137 * and GPUThread are parallel and unordered: iterations may occur in
4138 * any order, and multiple iterations may occur
4139 * simultaneously. Vectorized and GPULane are parallel and
4140 * synchronous: they act as if all iterations occur at the same time
4141 * in lockstep. */
enum class ForType {
    Serial,      // Ordered: each iteration appears to finish before the next begins.
    Parallel,    // Parallel and unordered.
    Vectorized,  // Parallel and synchronous: iterations run in lockstep.
    Unrolled,
    Extern,
    GPUBlock,    // Parallel and unordered, mapped to GPU blocks.
    GPUThread,   // Parallel and unordered, mapped to GPU threads.
    GPULane,     // Parallel and synchronous (lockstep), mapped to GPU SIMD lanes.
};
4152
4153/** Check if for_type executes for loop iterations in parallel and unordered. */
4154bool is_unordered_parallel(ForType for_type);
4155
4156/** Returns true if for_type executes for loop iterations in parallel. */
4157bool is_parallel(ForType for_type);
4158
4159/** A reference-counted handle to a statement node. */
struct Stmt : public IRHandle {
    Stmt() = default;
    Stmt(const BaseStmtNode *n)
        : IRHandle(n) {
    }

    /** Override get() to return a BaseStmtNode * instead of an IRNode * */
    HALIDE_ALWAYS_INLINE
    const BaseStmtNode *get() const {
        return (const Internal::BaseStmtNode *)ptr;
    }

    /** This lets you use a Stmt as a key in a map of the form
     * map<Stmt, Foo, Stmt::Compare> */
    struct Compare {
        // Orders by node identity (pointer), not by structural equality.
        bool operator()(const Stmt &a, const Stmt &b) const {
            return a.ptr < b.ptr;
        }
    };
};
4180
4181} // namespace Internal
4182} // namespace Halide
4183
4184#endif
4185#include <map>
4186
4187/** \file
4188 * Defines the lowering pass that insert mutex allocation code & locks
4189 * for the atomic nodes that require mutex locks. It also checks whether
 * the atomic operation is valid. It rejects algorithms that have indexing
 * on the left-hand side which references the buffer itself, e.g.
 * f(clamp(f(r), 0, 100)) = f(r) + 1,
 * if the SplitTuple pass does not lift out the Provide value as a let
 * expression. This is confirmed by checking whether the Provide nodes
 * inside an Atomic node have let binding values accessing the buffers
 * inside the atomic node.
4197 * Finally, it lifts the store indexing expressions inside the atomic node
4198 * outside of the atomic to avoid side-effects inside those expressions
4199 * being evaluated twice. */
4200
4201namespace Halide {
4202namespace Internal {
4203
4204class Function;
4205
4206Stmt add_atomic_mutex(Stmt s, const std::map<std::string, Function> &env);
4207
4208} // namespace Internal
4209} // namespace Halide
4210
4211#endif
4212#ifndef HALIDE_INTERNAL_ADD_IMAGE_CHECKS_H
4213#define HALIDE_INTERNAL_ADD_IMAGE_CHECKS_H
4214
4215/** \file
4216 *
4217 * Defines the lowering pass that adds the assertions that validate
4218 * input and output buffers.
4219 */
4220#include <map>
4221#include <string>
4222#include <vector>
4223
4224#ifndef HALIDE_BOUNDS_H
4225#define HALIDE_BOUNDS_H
4226
4227/** \file
4228 * Methods for computing the upper and lower bounds of an expression,
4229 * and the regions of a function read or written by a statement.
4230 */
4231
4232#ifndef HALIDE_INTERVAL_H
4233#define HALIDE_INTERVAL_H
4234
4235/** \file
4236 * Defines the Interval class
4237 */
4238
4239
4240namespace Halide {
4241namespace Internal {
4242
4243/** A class to represent ranges of Exprs. Can be unbounded above or below. */
struct Interval {

    /** Exprs to represent positive and negative infinity */
#ifdef COMPILING_HALIDE
    // Inside libHalide the static sentinel Exprs can be read directly.
    static HALIDE_ALWAYS_INLINE Expr pos_inf() {
        return pos_inf_expr;
    }
    static HALIDE_ALWAYS_INLINE Expr neg_inf() {
        return neg_inf_expr;
    }
#else
    // Outside libHalide, go through out-of-line accessors so the static
    // data members need not be exported (see the note on the private
    // *_noinline functions below).
    static Expr pos_inf() {
        return pos_inf_noinline();
    }
    static Expr neg_inf() {
        return neg_inf_noinline();
    }
#endif

    /** The lower and upper bound of the interval. They are included
     * in the interval. */
    Expr min, max;

    /** A default-constructed Interval is everything */
    Interval()
        : min(neg_inf()), max(pos_inf()) {
    }

    /** Construct an interval from a lower and upper bound. Both must
     * be defined Exprs (asserted); use everything()/nothing() for the
     * unbounded and empty cases. */
    Interval(const Expr &min, const Expr &max)
        : min(min), max(max) {
        internal_assert(min.defined() && max.defined());
    }

    /** The interval representing everything. */
    static Interval everything();

    /** The interval representing nothing. */
    static Interval nothing();

    /** Construct an interval representing a single point */
    static Interval single_point(const Expr &e);

    /** Is the interval the empty set */
    bool is_empty() const;

    /** Is the interval the entire range */
    bool is_everything() const;

    /** Is the interval just a single value (min == max) */
    bool is_single_point() const;

    /** Is the interval a particular single value */
    bool is_single_point(const Expr &e) const;

    /** Does the interval have a finite least upper bound */
    bool has_upper_bound() const;

    /** Does the interval have a finite greatest lower bound */
    bool has_lower_bound() const;

    /** Does the interval have a finite upper and lower bound */
    bool is_bounded() const;

    /** Is the interval the same as another interval */
    bool same_as(const Interval &other) const;

    /** Expand the interval to include another Interval */
    void include(const Interval &i);

    /** Expand the interval to include an Expr */
    void include(const Expr &e);

    /** Construct the smallest interval containing two intervals. */
    static Interval make_union(const Interval &a, const Interval &b);

    /** Construct the largest interval contained within two intervals. */
    static Interval make_intersection(const Interval &a, const Interval &b);

    /** An eagerly-simplifying max of two Exprs that respects infinities. */
    static Expr make_max(const Expr &a, const Expr &b);

    /** An eagerly-simplifying min of two Exprs that respects infinities. */
    static Expr make_min(const Expr &a, const Expr &b);

    /** Equivalent to same_as. Exists so that the autoscheduler can
     * compare two map<string, Interval> for equality in order to
     * cache computations. */
    bool operator==(const Interval &other) const;

private:
    static Expr neg_inf_expr, pos_inf_expr;

    // Never used inside libHalide; provided for Halide tests, to avoid needing to export
    // data fields in some build environments.
    static Expr pos_inf_noinline();
    static Expr neg_inf_noinline();
};
4342
4343/** A class to represent ranges of integers. Can be unbounded above or below, but
4344 * they cannot be empty. */
struct ConstantInterval {
    /** The lower and upper bound of the interval. They are included
     * in the interval. Presumably only meaningful when the matching
     * *_defined flag below is set -- confirm with the implementation. */
    int64_t min = 0, max = 0;
    bool min_defined = false, max_defined = false;

    /** A default-constructed ConstantInterval is everything */
    ConstantInterval();

    /** Construct an interval from a lower and upper bound. */
    ConstantInterval(int64_t min, int64_t max);

    /** The interval representing everything. */
    static ConstantInterval everything();

    /** Construct an interval representing a single point. */
    static ConstantInterval single_point(int64_t x);

    /** Construct intervals bounded above or below. */
    static ConstantInterval bounded_below(int64_t min);
    static ConstantInterval bounded_above(int64_t max);

    /** Is the interval the entire range */
    bool is_everything() const;

    /** Is the interval just a single value (min == max) */
    bool is_single_point() const;

    /** Is the interval a particular single value */
    bool is_single_point(int64_t x) const;

    /** Does the interval have a finite least upper bound */
    bool has_upper_bound() const;

    /** Does the interval have a finite greatest lower bound */
    bool has_lower_bound() const;

    /** Does the interval have a finite upper and lower bound */
    bool is_bounded() const;

    /** Expand the interval to include another Interval */
    void include(const ConstantInterval &i);

    /** Expand the interval to include a point */
    void include(int64_t x);

    /** Construct the smallest interval containing two intervals. */
    static ConstantInterval make_union(const ConstantInterval &a, const ConstantInterval &b);

    /** Equivalent to same_as. Exists so that the autoscheduler can
     * compare two map<string, Interval> for equality in order to
     * cache computations. */
    bool operator==(const ConstantInterval &other) const;
};
4399
4400} // namespace Internal
4401} // namespace Halide
4402
4403#endif
4404#ifndef HALIDE_SCOPE_H
4405#define HALIDE_SCOPE_H
4406
4407#include <iostream>
4408#include <map>
4409#include <stack>
4410#include <string>
4411#include <utility>
4412#include <vector>
4413
4414
4415/** \file
4416 * Defines the Scope class, which is used for keeping track of names in a scope while traversing IR
4417 */
4418
4419namespace Halide {
4420namespace Internal {
4421
4422/** A stack which can store one item very efficiently. Using this
4423 * instead of std::stack speeds up Scope substantially. */
template<typename T>
class SmallStack {
private:
    // The most recently pushed element lives out-of-band in cached_top;
    // everything beneath it is spilled into the vector.
    T cached_top;
    std::vector<T> spill;
    bool has_top = false;

public:
    SmallStack() = default;

    /** Remove the top element. Popping an already-empty stack just
     * resets the cached slot to a default-constructed value. */
    void pop() {
        if (spill.empty()) {
            has_top = false;
            cached_top = T();
        } else {
            cached_top = std::move(spill.back());
            spill.pop_back();
        }
    }

    /** Push a value, spilling the previous top (if any) into the vector. */
    void push(T t) {
        if (has_top) {
            spill.push_back(std::move(cached_top));
        }
        cached_top = std::move(t);
        has_top = true;
    }

    /** Return a copy of the top element. */
    T top() const {
        return cached_top;
    }

    /** Return a mutable reference to the top element. */
    T &top_ref() {
        return cached_top;
    }

    /** Return a const reference to the top element. */
    const T &top_ref() const {
        return cached_top;
    }

    bool empty() const {
        return !has_top;
    }

    size_t size() const {
        return has_top ? (spill.size() + 1) : 0;
    }
};
4472
4473template<>
4474class SmallStack<void> {
4475 // A stack of voids. Voids are all the same, so just record how many voids are in the stack
4476 int counter = 0;
4477
4478public:
4479 void pop() {
4480 counter--;
4481 }
4482 void push() {
4483 counter++;
4484 }
4485 bool empty() const {
4486 return counter == 0;
4487 }
4488};
4489
/** A common pattern when traversing Halide IR is that you need to
 * keep track of information when you find a Let or a LetStmt, and that
 * it should hide previous values with the same name until you leave the
 * Let or LetStmt nodes. This class helps with that. */
template<typename T = void>
class Scope {
private:
    // Each name maps to a stack of values; the top of the stack is the
    // innermost (most recently pushed) binding for that name.
    std::map<std::string, SmallStack<T>> table;

    // Optional fallback scope consulted by get/contains when a lookup
    // misses in this scope.
    const Scope<T> *containing_scope = nullptr;

public:
    Scope() = default;
    Scope(Scope &&that) noexcept = default;
    Scope &operator=(Scope &&that) noexcept = default;

    // Copying a scope object copies a large table full of strings and
    // stacks. Bad idea.
    Scope(const Scope<T> &) = delete;
    Scope<T> &operator=(const Scope<T> &) = delete;

    /** Set the parent scope. If lookups fail in this scope, they
     * check the containing scope before returning an error. Caller is
     * responsible for managing the memory of the containing scope. */
    void set_containing_scope(const Scope<T> *s) {
        containing_scope = s;
    }

    /** A const ref to an empty scope. Useful for default function
     * arguments, which would otherwise require a copy constructor
     * (with llvm in c++98 mode) */
    static const Scope<T> &empty_scope() {
        static Scope<T> _empty_scope;
        return _empty_scope;
    }

    /** Retrieve the value referred to by a name. Falls back to the
     * containing scope; it is a fatal internal error if the name is
     * bound in neither. */
    template<typename T2 = T,
             typename = typename std::enable_if<!std::is_same<T2, void>::value>::type>
    T2 get(const std::string &name) const {
        typename std::map<std::string, SmallStack<T>>::const_iterator iter = table.find(name);
        if (iter == table.end() || iter->second.empty()) {
            if (containing_scope) {
                return containing_scope->get(name);
            } else {
                // internal_error is fatal, so the read of iter below is
                // never reached on this path.
                internal_error << "Name not in Scope: " << name << "\n"
                               << *this << "\n";
            }
        }
        return iter->second.top();
    }

    /** Return a reference to an entry. Does not consider the containing scope. */
    template<typename T2 = T,
             typename = typename std::enable_if<!std::is_same<T2, void>::value>::type>
    T2 &ref(const std::string &name) {
        typename std::map<std::string, SmallStack<T>>::iterator iter = table.find(name);
        if (iter == table.end() || iter->second.empty()) {
            internal_error << "Name not in Scope: " << name << "\n"
                           << *this << "\n";
        }
        return iter->second.top_ref();
    }

    /** Tests if a name is in scope. Also checks the containing scope. */
    bool contains(const std::string &name) const {
        typename std::map<std::string, SmallStack<T>>::const_iterator iter = table.find(name);
        if (iter == table.end() || iter->second.empty()) {
            if (containing_scope) {
                return containing_scope->contains(name);
            } else {
                return false;
            }
        }
        return true;
    }

    /** How many nested definitions of a single name exist?
     * Does not consider the containing scope. */
    size_t count(const std::string &name) const {
        auto it = table.find(name);
        if (it == table.end()) {
            return 0;
        } else {
            return it->second.size();
        }
    }

    /** Add a new (name, value) pair to the current scope. Hide old
     * values that have this name until we pop this name.
     */
    template<typename T2 = T,
             typename = typename std::enable_if<!std::is_same<T2, void>::value>::type>
    void push(const std::string &name, T2 &&value) {
        table[name].push(std::forward<T2>(value));
    }

    /** Overload for Scope<void>: record the presence of a name without
     * attaching a value. */
    template<typename T2 = T,
             typename = typename std::enable_if<std::is_same<T2, void>::value>::type>
    void push(const std::string &name) {
        table[name].push();
    }

    /** A name goes out of scope. Restore whatever its old value
     * was (or remove it entirely if there was nothing else of the
     * same name in an outer scope) */
    void pop(const std::string &name) {
        typename std::map<std::string, SmallStack<T>>::iterator iter = table.find(name);
        internal_assert(iter != table.end()) << "Name not in Scope: " << name << "\n"
                                             << *this << "\n";
        iter->second.pop();
        // Erase empty stacks so iteration only visits live names.
        if (iter->second.empty()) {
            table.erase(iter);
        }
    }

    /** Iterate through the scope. Does not capture any containing scope. */
    class const_iterator {
        typename std::map<std::string, SmallStack<T>>::const_iterator iter;

    public:
        explicit const_iterator(const typename std::map<std::string, SmallStack<T>>::const_iterator &i)
            : iter(i) {
        }

        const_iterator() = default;

        bool operator!=(const const_iterator &other) {
            return iter != other.iter;
        }

        void operator++() {
            ++iter;
        }

        /** The name of the binding currently pointed at. */
        const std::string &name() {
            return iter->first;
        }

        /** The whole stack of values bound to that name. */
        const SmallStack<T> &stack() {
            return iter->second;
        }

        /** The innermost value bound to that name. */
        template<typename T2 = T,
                 typename = typename std::enable_if<!std::is_same<T2, void>::value>::type>
        const T2 &value() {
            return iter->second.top_ref();
        }
    };

    const_iterator cbegin() const {
        return const_iterator(table.begin());
    }

    const_iterator cend() const {
        return const_iterator(table.end());
    }

    /** Exchange contents (including containing-scope pointers) with
     * another Scope. */
    void swap(Scope<T> &other) {
        table.swap(other.table);
        std::swap(containing_scope, other.containing_scope);
    }
};
4652
4653template<typename T>
4654std::ostream &operator<<(std::ostream &stream, const Scope<T> &s) {
4655 stream << "{\n";
4656 typename Scope<T>::const_iterator iter;
4657 for (iter = s.cbegin(); iter != s.cend(); ++iter) {
4658 stream << " " << iter.name() << "\n";
4659 }
4660 stream << "}";
4661 return stream;
4662}
4663
4664/** Helper class for pushing/popping Scope<> values, to allow
4665 * for early-exit in Visitor/Mutators that preserves correctness.
4666 * Note that this name can be a bit confusing, since there are two "scopes"
4667 * involved here:
4668 * - the Scope object itself
4669 * - the lifetime of this helper object
4670 * The "Scoped" in this class name refers to the latter, as it temporarily binds
4671 * a name within the scope of this helper's lifetime. */
4672template<typename T = void>
4673struct ScopedBinding {
4674 Scope<T> *scope = nullptr;
4675 std::string name;
4676
4677 ScopedBinding() = default;
4678
4679 ScopedBinding(Scope<T> &s, const std::string &n, T value)
4680 : scope(&s), name(n) {
4681 scope->push(name, std::move(value));
4682 }
4683
4684 ScopedBinding(bool condition, Scope<T> &s, const std::string &n, const T &value)
4685 : scope(condition ? &s : nullptr), name(n) {
4686 if (condition) {
4687 scope->push(name, value);
4688 }
4689 }
4690
4691 bool bound() const {
4692 return scope != nullptr;
4693 }
4694
4695 ~ScopedBinding() {
4696 if (scope) {
4697 scope->pop(name);
4698 }
4699 }
4700
4701 // allow move but not copy
4702 ScopedBinding(const ScopedBinding &that) = delete;
4703 ScopedBinding(ScopedBinding &&that) noexcept
4704 : scope(that.scope),
4705 name(std::move(that.name)) {
4706 // The move constructor must null out scope, so we don't try to pop it
4707 that.scope = nullptr;
4708 }
4709
4710 void operator=(const ScopedBinding &that) = delete;
4711 void operator=(ScopedBinding &&that) = delete;
4712};
4713
4714template<>
4715struct ScopedBinding<void> {
4716 Scope<> *scope;
4717 std::string name;
4718 ScopedBinding(Scope<> &s, const std::string &n)
4719 : scope(&s), name(n) {
4720 scope->push(name);
4721 }
4722 ScopedBinding(bool condition, Scope<> &s, const std::string &n)
4723 : scope(condition ? &s : nullptr), name(n) {
4724 if (condition) {
4725 scope->push(name);
4726 }
4727 }
4728 ~ScopedBinding() {
4729 if (scope) {
4730 scope->pop(name);
4731 }
4732 }
4733
4734 // allow move but not copy
4735 ScopedBinding(const ScopedBinding &that) = delete;
4736 ScopedBinding(ScopedBinding &&that) noexcept
4737 : scope(that.scope),
4738 name(std::move(that.name)) {
4739 // The move constructor must null out scope, so we don't try to pop it
4740 that.scope = nullptr;
4741 }
4742
4743 void operator=(const ScopedBinding &that) = delete;
4744 void operator=(ScopedBinding &&that) = delete;
4745};
4746
4747} // namespace Internal
4748} // namespace Halide
4749
4750#endif
4751
4752namespace Halide {
4753namespace Internal {
4754
4755class Function;
4756
/** Maps from (function name, index) to the interval of values that
 * function can produce. NOTE(review): the int is presumably the tuple
 * component index -- confirm against compute_function_value_bounds. */
typedef std::map<std::pair<std::string, int>, Interval> FuncValueBounds;

/** A const ref to an empty FuncValueBounds, useful as a default argument. */
const FuncValueBounds &empty_func_value_bounds();
4760
4761/** Given an expression in some variables, and a map from those
4762 * variables to their bounds (in the form of (minimum possible value,
4763 * maximum possible value)), compute two expressions that give the
4764 * minimum possible value and the maximum possible value of this
4765 * expression. Max or min may be undefined expressions if the value is
4766 * not bounded above or below. If the expression is a vector, also
4767 * takes the bounds across the vector lanes and returns a scalar
4768 * result.
4769 *
4770 * This is for tasks such as deducing the region of a buffer
4771 * loaded by a chunk of code.
4772 */
4773Interval bounds_of_expr_in_scope(const Expr &expr,
4774 const Scope<Interval> &scope,
4775 const FuncValueBounds &func_bounds = empty_func_value_bounds(),
4776 bool const_bound = false);
4777
4778/** Given a varying expression, try to find a constant that is either:
4779 * An upper bound (always greater than or equal to the expression), or
4780 * A lower bound (always less than or equal to the expression)
4781 * If it fails, returns an undefined Expr. */
4782enum class Direction { Upper,
4783 Lower };
4784Expr find_constant_bound(const Expr &e, Direction d,
4785 const Scope<Interval> &scope = Scope<Interval>::empty_scope());
4786
4787/** Find bounds for a varying expression that are either constants or
4788 * +/-inf. */
4789Interval find_constant_bounds(const Expr &e, const Scope<Interval> &scope);
4790
4791/** Represents the bounds of a region of arbitrary dimension. Zero
4792 * dimensions corresponds to a scalar region. */
struct Box {
    /** The conditions under which this region may be touched. */
    Expr used;

    /** The bounds if it is touched. One Interval per dimension. */
    std::vector<Interval> bounds;

    Box() = default;
    /** Make a box of the given dimensionality; each Interval is
     * default-constructed (i.e. unbounded). */
    explicit Box(size_t sz)
        : bounds(sz) {
    }
    explicit Box(const std::vector<Interval> &b)
        : bounds(b) {
    }

    /** The number of dimensions. Zero corresponds to a scalar region. */
    size_t size() const {
        return bounds.size();
    }
    bool empty() const {
        return bounds.empty();
    }
    Interval &operator[](size_t i) {
        return bounds[i];
    }
    const Interval &operator[](size_t i) const {
        return bounds[i];
    }
    void resize(size_t sz) {
        bounds.resize(sz);
    }
    void push_back(const Interval &i) {
        bounds.push_back(i);
    }

    /** Check if the used condition is defined and not trivially true. */
    bool maybe_unused() const;

    friend std::ostream &operator<<(std::ostream &stream, const Box &b);
};
4832
4833/** Expand box a to encompass box b */
4834void merge_boxes(Box &a, const Box &b);
4835
4836/** Test if box a could possibly overlap box b. */
4837bool boxes_overlap(const Box &a, const Box &b);
4838
4839/** The union of two boxes */
4840Box box_union(const Box &a, const Box &b);
4841
4842/** The intersection of two boxes */
4843Box box_intersection(const Box &a, const Box &b);
4844
4845/** Test if box a provably contains box b */
4846bool box_contains(const Box &a, const Box &b);
4847
4848/** Compute rectangular domains large enough to cover all the 'Call's
4849 * to each function that occurs within a given statement or
4850 * expression. This is useful for figuring out what regions of things
4851 * to evaluate. */
4852// @{
4853std::map<std::string, Box> boxes_required(const Expr &e,
4854 const Scope<Interval> &scope = Scope<Interval>::empty_scope(),
4855 const FuncValueBounds &func_bounds = empty_func_value_bounds());
4856std::map<std::string, Box> boxes_required(Stmt s,
4857 const Scope<Interval> &scope = Scope<Interval>::empty_scope(),
4858 const FuncValueBounds &func_bounds = empty_func_value_bounds());
4859// @}
4860
4861/** Compute rectangular domains large enough to cover all the
4862 * 'Provides's to each function that occurs within a given statement
4863 * or expression. */
4864// @{
4865std::map<std::string, Box> boxes_provided(const Expr &e,
4866 const Scope<Interval> &scope = Scope<Interval>::empty_scope(),
4867 const FuncValueBounds &func_bounds = empty_func_value_bounds());
4868std::map<std::string, Box> boxes_provided(Stmt s,
4869 const Scope<Interval> &scope = Scope<Interval>::empty_scope(),
4870 const FuncValueBounds &func_bounds = empty_func_value_bounds());
4871// @}
4872
4873/** Compute rectangular domains large enough to cover all the 'Call's
4874 * and 'Provides's to each function that occurs within a given
4875 * statement or expression. */
4876// @{
4877std::map<std::string, Box> boxes_touched(const Expr &e,
4878 const Scope<Interval> &scope = Scope<Interval>::empty_scope(),
4879 const FuncValueBounds &func_bounds = empty_func_value_bounds());
4880std::map<std::string, Box> boxes_touched(Stmt s,
4881 const Scope<Interval> &scope = Scope<Interval>::empty_scope(),
4882 const FuncValueBounds &func_bounds = empty_func_value_bounds());
4883// @}
4884
4885/** Variants of the above that are only concerned with a single function. */
4886// @{
4887Box box_required(const Expr &e, const std::string &fn,
4888 const Scope<Interval> &scope = Scope<Interval>::empty_scope(),
4889 const FuncValueBounds &func_bounds = empty_func_value_bounds());
4890Box box_required(Stmt s, const std::string &fn,
4891 const Scope<Interval> &scope = Scope<Interval>::empty_scope(),
4892 const FuncValueBounds &func_bounds = empty_func_value_bounds());
4893
4894Box box_provided(const Expr &e, const std::string &fn,
4895 const Scope<Interval> &scope = Scope<Interval>::empty_scope(),
4896 const FuncValueBounds &func_bounds = empty_func_value_bounds());
4897Box box_provided(Stmt s, const std::string &fn,
4898 const Scope<Interval> &scope = Scope<Interval>::empty_scope(),
4899 const FuncValueBounds &func_bounds = empty_func_value_bounds());
4900
4901Box box_touched(const Expr &e, const std::string &fn,
4902 const Scope<Interval> &scope = Scope<Interval>::empty_scope(),
4903 const FuncValueBounds &func_bounds = empty_func_value_bounds());
4904Box box_touched(Stmt s, const std::string &fn,
4905 const Scope<Interval> &scope = Scope<Interval>::empty_scope(),
4906 const FuncValueBounds &func_bounds = empty_func_value_bounds());
4907// @}
4908
4909/** Compute the maximum and minimum possible value for each function
4910 * in an environment. */
4911FuncValueBounds compute_function_value_bounds(const std::vector<std::string> &order,
4912 const std::map<std::string, Function> &env);
4913
4914void bounds_test();
4915
4916} // namespace Internal
4917} // namespace Halide
4918
4919#endif
4920
4921namespace Halide {
4922
4923struct Target;
4924
4925namespace Internal {
4926
4927class Function;
4928
/** Insert checks to make sure a statement doesn't read out of bounds
 * on inputs or outputs, and that the inputs and outputs conform to
 * the format required (e.g. stride.0 must be 1).
 *
 * 'outputs' lists the pipeline output functions, 'order' is the
 * realization order of the functions in 'env', and 'fb' carries
 * precomputed value bounds for those functions.
 */
Stmt add_image_checks(const Stmt &s,
                      const std::vector<Function> &outputs,
                      const Target &t,
                      const std::vector<std::string> &order,
                      const std::map<std::string, Function> &env,
                      const FuncValueBounds &fb,
                      bool will_inject_host_copies);
4940
4941} // namespace Internal
4942} // namespace Halide
4943
4944#endif
4945#ifndef HALIDE_INTERNAL_ADD_PARAMETER_CHECKS_H
4946#define HALIDE_INTERNAL_ADD_PARAMETER_CHECKS_H
4947
4948/** \file
4949 *
4950 * Defines the lowering pass that adds the assertions that validate
4951 * scalar parameters.
4952 */
4953#include <vector>
4954
4955#ifndef HALIDE_TARGET_H
4956#define HALIDE_TARGET_H
4957
4958/** \file
4959 * Defines the structure that describes a Halide target.
4960 */
4961
4962#include <bitset>
4963#include <cstdint>
4964#include <string>
4965
4966#ifndef HALIDE_DEVICEAPI_H
4967#define HALIDE_DEVICEAPI_H
4968
4969/** \file
4970 * Defines DeviceAPI.
4971 */
4972
4973#include <string>
4974#include <vector>
4975
4976namespace Halide {
4977
/** An enum describing a type of device API. Used by schedules, and in
 * the For loop IR node. */
enum class DeviceAPI {
    None,  ///< Used to denote for loops that run on the same device as the containing code.
    Host,
    Default_GPU,
    CUDA,
    OpenCL,
    OpenGLCompute,
    Metal,
    Hexagon,
    HexagonDma,
    D3D12Compute,
};
4992
/** An array containing all the device apis. Useful for iterating
 * through them. Must be kept in sync with the DeviceAPI enum above. */
const DeviceAPI all_device_apis[] = {DeviceAPI::None,
                                     DeviceAPI::Host,
                                     DeviceAPI::Default_GPU,
                                     DeviceAPI::CUDA,
                                     DeviceAPI::OpenCL,
                                     DeviceAPI::OpenGLCompute,
                                     DeviceAPI::Metal,
                                     DeviceAPI::Hexagon,
                                     DeviceAPI::HexagonDma,
                                     DeviceAPI::D3D12Compute};
5005
5006} // namespace Halide
5007
5008#endif // HALIDE_DEVICEAPI_H
5009
5010namespace Halide {
5011
/** A struct representing a target machine and os to generate code for. */
struct Target {
    /** The operating system used by the target. Determines which
     * system calls to generate.
     * Corresponds to os_name_map in Target.cpp. */
    enum OS {
        OSUnknown = 0,
        Linux,
        Windows,
        OSX,
        Android,
        IOS,
        QuRT,
        NoOS,
        Fuchsia,
        WebAssemblyRuntime
    } os = OSUnknown;

    /** The architecture used by the target. Determines the
     * instruction set to use.
     * Corresponds to arch_name_map in Target.cpp. */
    enum Arch {
        ArchUnknown = 0,
        X86,
        ARM,
        MIPS,
        Hexagon,
        POWERPC,
        WebAssembly,
        RISCV
    } arch = ArchUnknown;

    /** The bit-width of the target machine. Must be 0 for unknown, or 32 or 64. */
    int bits = 0;

    /** Optional features a target can have.
     * Corresponds to feature_name_map in Target.cpp.
     * See definitions in HalideRuntime.h for full information.
     */
    enum Feature {
        JIT = halide_target_feature_jit,
        Debug = halide_target_feature_debug,
        NoAsserts = halide_target_feature_no_asserts,
        NoBoundsQuery = halide_target_feature_no_bounds_query,
        SSE41 = halide_target_feature_sse41,
        AVX = halide_target_feature_avx,
        AVX2 = halide_target_feature_avx2,
        FMA = halide_target_feature_fma,
        FMA4 = halide_target_feature_fma4,
        F16C = halide_target_feature_f16c,
        ARMv7s = halide_target_feature_armv7s,
        NoNEON = halide_target_feature_no_neon,
        VSX = halide_target_feature_vsx,
        POWER_ARCH_2_07 = halide_target_feature_power_arch_2_07,
        CUDA = halide_target_feature_cuda,
        CUDACapability30 = halide_target_feature_cuda_capability30,
        CUDACapability32 = halide_target_feature_cuda_capability32,
        CUDACapability35 = halide_target_feature_cuda_capability35,
        CUDACapability50 = halide_target_feature_cuda_capability50,
        CUDACapability61 = halide_target_feature_cuda_capability61,
        CUDACapability70 = halide_target_feature_cuda_capability70,
        CUDACapability75 = halide_target_feature_cuda_capability75,
        CUDACapability80 = halide_target_feature_cuda_capability80,
        OpenCL = halide_target_feature_opencl,
        CLDoubles = halide_target_feature_cl_doubles,
        CLHalf = halide_target_feature_cl_half,
        CLAtomics64 = halide_target_feature_cl_atomic64,
        OpenGLCompute = halide_target_feature_openglcompute,
        EGL = halide_target_feature_egl,
        UserContext = halide_target_feature_user_context,
        Matlab = halide_target_feature_matlab,
        Profile = halide_target_feature_profile,
        NoRuntime = halide_target_feature_no_runtime,
        Metal = halide_target_feature_metal,
        CPlusPlusMangling = halide_target_feature_c_plus_plus_mangling,
        LargeBuffers = halide_target_feature_large_buffers,
        HexagonDma = halide_target_feature_hexagon_dma,
        HVX_128 = halide_target_feature_hvx_128,
        HVX = HVX_128,
        HVX_v62 = halide_target_feature_hvx_v62,
        HVX_v65 = halide_target_feature_hvx_v65,
        HVX_v66 = halide_target_feature_hvx_v66,
        HVX_shared_object = halide_target_feature_hvx_use_shared_object,
        FuzzFloatStores = halide_target_feature_fuzz_float_stores,
        SoftFloatABI = halide_target_feature_soft_float_abi,
        MSAN = halide_target_feature_msan,
        AVX512 = halide_target_feature_avx512,
        AVX512_KNL = halide_target_feature_avx512_knl,
        AVX512_Skylake = halide_target_feature_avx512_skylake,
        AVX512_Cannonlake = halide_target_feature_avx512_cannonlake,
        AVX512_SapphireRapids = halide_target_feature_avx512_sapphirerapids,
        TraceLoads = halide_target_feature_trace_loads,
        TraceStores = halide_target_feature_trace_stores,
        TraceRealizations = halide_target_feature_trace_realizations,
        TracePipeline = halide_target_feature_trace_pipeline,
        D3D12Compute = halide_target_feature_d3d12compute,
        StrictFloat = halide_target_feature_strict_float,
        TSAN = halide_target_feature_tsan,
        ASAN = halide_target_feature_asan,
        CheckUnsafePromises = halide_target_feature_check_unsafe_promises,
        EmbedBitcode = halide_target_feature_embed_bitcode,
        EnableLLVMLoopOpt = halide_target_feature_enable_llvm_loop_opt,
        DisableLLVMLoopOpt = halide_target_feature_disable_llvm_loop_opt,
        WasmSimd128 = halide_target_feature_wasm_simd128,
        WasmSignExt = halide_target_feature_wasm_signext,
        WasmSatFloatToInt = halide_target_feature_wasm_sat_float_to_int,
        WasmThreads = halide_target_feature_wasm_threads,
        WasmBulkMemory = halide_target_feature_wasm_bulk_memory,
        SVE = halide_target_feature_sve,
        SVE2 = halide_target_feature_sve2,
        ARMDotProd = halide_target_feature_arm_dot_prod,
        LLVMLargeCodeModel = halide_llvm_large_code_model,
        RVV = halide_target_feature_rvv,
        ARMv81a = halide_target_feature_armv81a,
        FeatureEnd = halide_target_feature_end
    };

    /** Construct a fully-unknown Target (os/arch/bits all unknown/0,
     * no features set). */
    Target() = default;

    /** Construct a Target from an explicit os, arch and bit-width,
     * optionally setting an initial list of features. */
    Target(OS o, Arch a, int b, const std::vector<Feature> &initial_features = std::vector<Feature>())
        : os(o), arch(a), bits(b) {
        for (const auto &f : initial_features) {
            set_feature(f);
        }
    }

    /** Given a string of the form used in HL_TARGET
     * (e.g. "x86-64-avx"), construct the Target it specifies. Note
     * that this always starts with the result of get_host_target(),
     * replacing only the parts found in the target string, so if you
     * omit (say) an OS specification, the host OS will be used
     * instead. An empty string is exactly equivalent to
     * get_host_target().
     *
     * Invalid target strings will fail with a user_error.
     */
    // @{
    explicit Target(const std::string &s);
    explicit Target(const char *s);
    // @}

    /** Check if a target string is valid. */
    static bool validate_target_string(const std::string &s);

    /** Return true if any of the arch/bits/os fields are "unknown"/0;
    return false otherwise. */
    bool has_unknowns() const;

    /** Set (value = true) or clear (value = false) a single feature flag. */
    void set_feature(Feature f, bool value = true);

    /** Set or clear each of the given feature flags. */
    void set_features(const std::vector<Feature> &features_to_set, bool value = true);

    /** Test whether the given feature flag is set. */
    bool has_feature(Feature f) const;

    inline bool has_feature(halide_target_feature_t f) const {
        return has_feature((Feature)f);
    }

    /** Test whether at least one of the given feature flags is set. */
    bool features_any_of(const std::vector<Feature> &test_features) const;

    /** Test whether all of the given feature flags are set. */
    bool features_all_of(const std::vector<Feature> &test_features) const;

    /** Return a copy of the target with the given feature set.
     * This is convenient when enabling certain features (e.g. NoBoundsQuery)
     * in an initialization list, where the target to be mutated may be
     * a const reference. */
    Target with_feature(Feature f) const;

    /** Return a copy of the target with the given feature cleared.
     * This is convenient when disabling certain features (e.g. NoBoundsQuery)
     * in an initialization list, where the target to be mutated may be
     * a const reference. */
    Target without_feature(Feature f) const;

    /** Is a fully featured GPU compute runtime enabled? I.e. is
     * Func::gpu_tile and similar going to work? Currently includes
     * CUDA, OpenCL, Metal and D3D12Compute. We do not include OpenGL,
     * because it is not capable of gpgpu, and is not scheduled via
     * Func::gpu_tile.
     * TODO: Should OpenGLCompute be included here? */
    bool has_gpu_feature() const;

    /** Does this target allow using a certain type. Generally all
     * types except 64-bit float and int/uint should be supported by
     * all backends.
     *
     * It is likely better to call the version below which takes a DeviceAPI.
     */
    bool supports_type(const Type &t) const;

    /** Does this target allow using a certain type on a certain device.
     * This is the preferred version of this routine.
     */
    bool supports_type(const Type &t, DeviceAPI device) const;

    /** Returns whether a particular device API can be used with this
     * Target. */
    bool supports_device_api(DeviceAPI api) const;

    /** If this Target (including all Features) requires a specific DeviceAPI,
     * return it. If it doesn't, return DeviceAPI::None. If the Target has
     * features with multiple (different) DeviceAPI requirements, the result
     * will be an arbitrary DeviceAPI. */
    DeviceAPI get_required_device_api() const;

    /** Two Targets are equal when os, arch, bits and the full feature
     * bitset all match. */
    bool operator==(const Target &other) const {
        return os == other.os &&
               arch == other.arch &&
               bits == other.bits &&
               features == other.features;
    }

    bool operator!=(const Target &other) const {
        return !(*this == other);
    }

    /**
     * Create a "greatest common denominator" runtime target that is compatible with
     * both this target and \p other. Used by generators to conveniently select a suitable
     * runtime when linking together multiple functions.
     *
     * @param other The other target from which we compute the gcd target.
     * @param[out] result The gcd target if we return true, otherwise unmodified. Can be the same as *this.
     * @return Whether it was possible to find a compatible target (true) or not.
     */
    bool get_runtime_compatible_target(const Target &other, Target &result);

    /** Convert the Target into a string form that can be reconstituted
     * by merge_string(), which will always be of the form
     *
     * arch-bits-os-feature1-feature2...featureN.
     *
     * Note that is guaranteed that Target(t1.to_string()) == t1,
     * but not that Target(s).to_string() == s (since there can be
     * multiple strings that parse to the same Target)...
     * *unless* t1 contains 'unknown' fields (in which case you'll get a string
     * that can't be parsed, which is intentional).
     */
    std::string to_string() const;

    /** Given a data type, return an estimate of the "natural" vector size
     * for that data type when compiling for this Target. */
    int natural_vector_size(const Halide::Type &t) const;

    /** Given a data type, return an estimate of the "natural" vector size
     * for that data type when compiling for this Target. */
    template<typename data_t>
    int natural_vector_size() const {
        return natural_vector_size(type_of<data_t>());
    }

    /** Return true iff 64 bits and has_feature(LargeBuffers). */
    bool has_large_buffers() const {
        return bits == 64 && has_feature(LargeBuffers);
    }

    /** Return the maximum buffer size in bytes supported on this
     * Target. This is 2^31 - 1 except on 64-bit targets when the LargeBuffers
     * feature is enabled, which expands the maximum to 2^63 - 1. */
    int64_t maximum_buffer_size() const {
        if (has_large_buffers()) {
            return (((uint64_t)1) << 63) - 1;
        } else {
            return (((uint64_t)1) << 31) - 1;
        }
    }

    /** Get the minimum cuda capability found as an integer. Returns
     * 20 (our minimum supported cuda compute capability) if no cuda
     * features are set. */
    int get_cuda_capability_lower_bound() const;

    /** Was libHalide compiled with support for this target? */
    bool supported() const;

    /** Return a bitset of the Features set in this Target (set = 1).
     * Note that while this happens to be the current internal representation,
     * that might not always be the case. */
    const std::bitset<FeatureEnd> &get_features_bitset() const {
        return features;
    }

    /** Return the name corresponding to a given Feature, in the form
     * used to construct Target strings (e.g., Feature::Debug is "debug" and not "Debug"). */
    static std::string feature_to_name(Target::Feature feature);

    /** Return the feature corresponding to a given name, in the form
     * used to construct Target strings (e.g., Feature::Debug is "debug" and not "Debug").
     * If the string is not a known feature name, return FeatureEnd. */
    static Target::Feature feature_from_name(const std::string &name);

private:
    /** A bitmask that stores the active features. */
    std::bitset<FeatureEnd> features;
};
5305
/** Return the target corresponding to the host machine. */
Target get_host_target();

/** Return the target that Halide will use. If HL_TARGET is set it
 * uses that. Otherwise calls \ref get_host_target */
Target get_target_from_environment();

/** Return the target that Halide will use for jit-compilation. If
 * HL_JIT_TARGET is set it uses that. Otherwise calls \ref
 * get_host_target. Throws an error if the architecture, bit width,
 * and OS of the target do not match the host target, so this is only
 * useful for controlling the feature set. */
Target get_jit_target_from_environment();

/** Get the Target feature corresponding to a DeviceAPI. For device
 * apis that do not correspond to any single target feature, returns
 * Target::FeatureEnd */
Target::Feature target_feature_for_device_api(DeviceAPI api);

namespace Internal {

// Self-test entry point for the Target module (implementation elsewhere).
void target_test();
}  // namespace Internal
5329
5330} // namespace Halide
5331
5332#endif
5333
5334namespace Halide {
5335namespace Internal {
5336
/** Insert checks to make sure that all referenced parameters meet
 * their constraints. Also injects any custom requirements provided
 * by the user ('requirements'). */
Stmt add_parameter_checks(const std::vector<Stmt> &requirements, Stmt s, const Target &t);
5341
5342} // namespace Internal
5343} // namespace Halide
5344
5345#endif
5346#ifndef HALIDE_ALIGN_LOADS_H
5347#define HALIDE_ALIGN_LOADS_H
5348
5349/** \file
5350 * Defines a lowering pass that rewrites unaligned loads into
5351 * sequences of aligned loads.
5352 */
5353
5354namespace Halide {
5355namespace Internal {
5356
/** Attempt to rewrite unaligned loads from buffers which are known to
 * be aligned to instead load aligned vectors that cover the original
 * load, and then slice the original load out of the aligned
 * vectors. 'alignment' is the required alignment. */
Stmt align_loads(const Stmt &s, int alignment);
5362
5363} // namespace Internal
5364} // namespace Halide
5365
5366#endif
5367#ifndef HALIDE_ALLOCATION_BOUNDS_INFERENCE_H
5368#define HALIDE_ALLOCATION_BOUNDS_INFERENCE_H
5369
5370/** \file
5371 * Defines the lowering pass that determines how large internal allocations should be.
5372 */
5373#include <map>
5374#include <string>
5375#include <utility>
5376
5377
5378namespace Halide {
5379namespace Internal {
5380
5381class Function;
5382
/** Take a partial statement with Realize nodes in terms of
 * variables, and define values for those variables. */
Stmt allocation_bounds_inference(Stmt s,
                                 const std::map<std::string, Function> &env,
                                 const std::map<std::pair<std::string, int>, Interval> &func_bounds);
5388} // namespace Internal
5389} // namespace Halide
5390
5391#endif
5392#ifndef APPLY_SPLIT_H
5393#define APPLY_SPLIT_H
5394
5395/** \file
5396 *
5397 * Defines method that returns a list of let stmts, substitutions, and
5398 * predicates to be added given a split schedule.
5399 */
5400
5401#include <map>
5402#include <string>
5403#include <utility>
5404#include <vector>
5405
5406#ifndef HALIDE_SCHEDULE_H
5407#define HALIDE_SCHEDULE_H
5408
5409/** \file
5410 * Defines the internal representation of the schedule for a function
5411 */
5412
5413#include <map>
5414#include <string>
5415#include <utility>
5416#include <vector>
5417
5418#ifndef HALIDE_FUNCTION_PTR_H
5419#define HALIDE_FUNCTION_PTR_H
5420
5421
5422namespace Halide {
5423namespace Internal {
5424
5425/** Functions are allocated in groups for memory management. Each
5426 * group has a ref count associated with it. All within-group
5427 * references must be weak. If there are any references from outside
5428 * the group, at least one must be strong. Within-group references
5429 * may form cycles, but there may not be reference cycles that span
5430 * multiple groups. These rules are not enforced automatically. */
5431struct FunctionGroup;
5432
5433/** The opaque struct describing a Halide function. Wrap it in a
5434 * Function object to access it. */
5435struct FunctionContents;
5436
/** A possibly-weak pointer to a Halide function. Take care to follow
 * the rules mentioned above. Preserves weakness/strength on copy.
 *
 * Note that Function objects are always strong pointers to Halide
 * functions.
 */
struct FunctionPtr {
    /** A strong and weak pointer to the group. Only one of these
     * should be non-zero. */
    // @{
    IntrusivePtr<FunctionGroup> strong;
    FunctionGroup *weak = nullptr;
    // @}

    /** The index of the function within the group. */
    int idx = 0;

    /** Get a pointer to the group this Function belongs to. */
    FunctionGroup *group() const {
        return weak ? weak : strong.get();
    }

    /** Get the opaque FunctionContents object this pointer refers
     * to. Wrap it in a Function to do anything interesting with it. */
    // @{
    FunctionContents *get() const;

    FunctionContents &operator*() const {
        return *get();
    }

    FunctionContents *operator->() const {
        return get();
    }
    // @}

    /** Convert from a strong reference to a weak reference. Does
     * nothing if the pointer is undefined, or if the reference is
     * already weak. */
    void weaken() {
        weak = group();
        strong = nullptr;
    }

    /** Convert from a weak reference to a strong reference. Does
     * nothing if the pointer is undefined, or if the reference is
     * already strong. */
    void strengthen() {
        strong = group();
        weak = nullptr;
    }

    /** Check if the reference is defined. */
    bool defined() const {
        return weak || strong.defined();
    }

    /** Check if two FunctionPtrs refer to the same Function.
     * Compares both the group and the index within the group. */
    bool same_as(const FunctionPtr &other) const {
        return idx == other.idx && group() == other.group();
    }

    /** Pointer comparison, for using FunctionPtrs as keys in maps and
     * sets. */
    bool operator<(const FunctionPtr &other) const {
        return get() < other.get();
    }
};
5505
5506} // namespace Internal
5507} // namespace Halide
5508
5509#endif
5510#ifndef HALIDE_PARAMETER_H
5511#define HALIDE_PARAMETER_H
5512
5513/** \file
 * Defines the internal representation of parameters to halide pipelines
5515 */
5516#include <string>
5517
5518
5519namespace Halide {
5520
5521struct ArgumentEstimates;
5522template<typename T>
5523class Buffer;
5524struct Expr;
5525struct Type;
5526enum class MemoryType;
5527
5528namespace Internal {
5529
5530struct ParameterContents;
5531
/** A reference-counted handle to a parameter to a halide
 * pipeline. May be a scalar parameter or a buffer */
class Parameter {
    // Internal validity checks (defined elsewhere); used by the
    // templated accessors below to validate preconditions.
    void check_defined() const;
    void check_is_buffer() const;
    void check_is_scalar() const;
    void check_dim_ok(int dim) const;
    void check_type(const Type &t) const;

protected:
    IntrusivePtr<ParameterContents> contents;

public:
    /** Construct a new undefined handle */
    Parameter() = default;

    /** Construct a new parameter of the given type. If the second
     * argument is true, this is a buffer parameter of the given
     * dimensionality, otherwise, it is a scalar parameter (and the
     * dimensionality should be zero). The parameter will be given a
     * unique auto-generated name. */
    Parameter(const Type &t, bool is_buffer, int dimensions);

    /** Construct a new parameter of the given type with name given by
     * the fourth argument. If the second argument is true, this is a
     * buffer parameter, otherwise, it is a scalar parameter. The
     * third argument gives the dimensionality of the buffer
     * parameter. It should be zero for scalar parameters. The name
     * passed in is treated as explicitly specified (as opposed to
     * autogenerated). */
    Parameter(const Type &t, bool is_buffer, int dimensions, const std::string &name);

    Parameter(const Parameter &) = default;
    Parameter &operator=(const Parameter &) = default;
    Parameter(Parameter &&) = default;
    Parameter &operator=(Parameter &&) = default;

    /** Get the type of this parameter */
    Type type() const;

    /** Get the dimensionality of this parameter. Zero for scalars. */
    int dimensions() const;

    /** Get the name of this parameter */
    const std::string &name() const;

    /** Does this parameter refer to a buffer/image? */
    bool is_buffer() const;

    /** If the parameter is a scalar parameter, get its currently
     * bound value. Only relevant when jitting */
    template<typename T>
    HALIDE_NO_USER_CODE_INLINE T scalar() const {
        // Errors if T does not match the parameter's declared type.
        check_type(type_of<T>());
        return *((const T *)(scalar_address()));
    }

    /** This returns the current value of scalar<type()>()
     * as an Expr. */
    Expr scalar_expr() const;

    /** If the parameter is a scalar parameter, set its current
     * value. Only relevant when jitting */
    template<typename T>
    HALIDE_NO_USER_CODE_INLINE void set_scalar(T val) {
        // Errors if T does not match the parameter's declared type.
        check_type(type_of<T>());
        *((T *)(scalar_address())) = val;
    }

    /** If the parameter is a scalar parameter, set its current
     * value. Only relevant when jitting */
    HALIDE_NO_USER_CODE_INLINE void set_scalar(const Type &val_type, halide_scalar_value_t val) {
        check_type(val_type);
        // Copy only as many bytes as the declared type occupies.
        memcpy(scalar_address(), &val, val_type.bytes());
    }

    /** If the parameter is a buffer parameter, get its currently
     * bound buffer. Only relevant when jitting */
    Buffer<void> buffer() const;

    /** Get the raw currently-bound buffer. null if unbound */
    const halide_buffer_t *raw_buffer() const;

    /** If the parameter is a buffer parameter, set its current
     * value. Only relevant when jitting */
    void set_buffer(const Buffer<void> &b);

    /** Get the pointer to the current value of the scalar
     * parameter. For a given parameter, this address will never
     * change. Only relevant when jitting. */
    void *scalar_address() const;

    /** Tests if this handle is the same as another handle */
    bool same_as(const Parameter &other) const;

    /** Tests if this handle is non-nullptr */
    bool defined() const;

    /** Get and set constraints for the min, extent, stride, and estimates on
     * the min/extent. */
    //@{
    void set_min_constraint(int dim, Expr e);
    void set_extent_constraint(int dim, Expr e);
    void set_stride_constraint(int dim, Expr e);
    void set_min_constraint_estimate(int dim, Expr min);
    void set_extent_constraint_estimate(int dim, Expr extent);
    void set_host_alignment(int bytes);
    Expr min_constraint(int dim) const;
    Expr extent_constraint(int dim) const;
    Expr stride_constraint(int dim) const;
    Expr min_constraint_estimate(int dim) const;
    Expr extent_constraint_estimate(int dim) const;
    int host_alignment() const;
    //@}

    /** Get and set constraints for scalar parameters. These are used
     * directly by Param, so they must be exported. */
    // @{
    void set_min_value(const Expr &e);
    Expr min_value() const;
    void set_max_value(const Expr &e);
    Expr max_value() const;
    void set_estimate(Expr e);
    Expr estimate() const;
    // @}

    /** Get and set the default values for scalar parameters. At present, these
     * are used only to emit the default values in the metadata. There isn't
     * yet a public API in Param<> for them (this is used internally by the
     * Generator code). */
    // @{
    void set_default_value(const Expr &e);
    Expr default_value() const;
    // @}

    /** Order Parameters by their IntrusivePtr so they can be used
     * to index maps. */
    bool operator<(const Parameter &other) const {
        return contents < other.contents;
    }

    /** Get the ArgumentEstimates appropriate for this Parameter. */
    ArgumentEstimates get_argument_estimates() const;

    /** Get and set the memory type this parameter should be stored in. */
    // @{
    void store_in(MemoryType memory_type);
    MemoryType memory_type() const;
    // @}
};
5679
/** Validate arguments to a call to a func, image or imageparam.
 * 'dims' is the expected dimensionality; 'args' may be mutated. */
void check_call_arg_types(const std::string &name, std::vector<Expr> *args, int dims);
5682
5683} // namespace Internal
5684} // namespace Halide
5685
5686#endif
5687#ifndef HALIDE_PREFETCH_DIRECTIVE_H
5688#define HALIDE_PREFETCH_DIRECTIVE_H
5689
5690/** \file
5691 * Defines the PrefetchDirective struct
5692 */
5693
5694#include <string>
5695
5696
5697namespace Halide {
5698
/** Different ways to handle accesses outside the original extents in a prefetch. */
enum class PrefetchBoundStrategy {
    /** Clamp the prefetched exprs by intersecting the prefetched region with
     * the original extents. This may make the exprs of the prefetched region
     * more complicated. */
    Clamp,

    /** Guard the prefetch with if-guards that ignores the prefetch if
     * any of the prefetched region ever goes beyond the original extents
     * (i.e. all or nothing). */
    GuardWithIf,

    /** Leave the prefetched exprs as are (no if-guards around the prefetch
     * and no intersecting with the original extents). This makes the prefetch
     * exprs simpler but this may cause prefetching of region outside the original
     * extents. This is good if prefetch won't fault when accessing region
     * outside the original extents. */
    NonFaulting
};
5718
5719namespace Internal {
5720
/** Scheduling directive describing a single prefetch operation. */
struct PrefetchDirective {
    std::string name;  // Name of the buffer/Func being prefetched (presumably; confirm against scheduling code).
    std::string var;   // Loop variable at which the prefetch is issued.
    Expr offset;       // How far ahead (in iterations of 'var') to prefetch.
    PrefetchBoundStrategy strategy;  // How to handle accesses outside the original extents.
    // If it's a prefetch load from an image parameter, this points to that.
    Parameter param;
};
5729
5730} // namespace Internal
5731
5732} // namespace Halide
5733
5734#endif // HALIDE_PREFETCH_DIRECTIVE_H
5735
5736namespace Halide {
5737
5738class Func;
5739struct VarOrRVar;
5740
5741namespace Internal {
5742class Function;
5743struct FunctionContents;
5744struct LoopLevelContents;
5745} // namespace Internal
5746
/** Different ways to handle a tail case in a split when the
 * factor does not provably divide the extent. */
enum class TailStrategy {
    /** Round up the extent to be a multiple of the split
     * factor. Not legal for RVars, as it would change the meaning
     * of the algorithm. Pros: generates the simplest, fastest
     * code. Cons: if used on a stage that reads from the input or
     * writes to the output, constrains the input or output size
     * to be a multiple of the split factor. */
    RoundUp,

    /** Guard the inner loop with an if statement that prevents
     * evaluation beyond the original extent. Always legal. The if
     * statement is treated like a boundary condition, and
     * factored out into a loop epilogue if possible. Pros: no
     * redundant re-evaluation; does not constrain input or
     * output sizes. Cons: increases code size due to separate
     * tail-case handling; vectorization will scalarize in the tail
     * case to handle the if statement. */
    GuardWithIf,

    /** Guard the inner loop with an if statement that prevents
     * evaluation beyond the original extent, with a hint that the
     * if statement should be implemented with predicated operations.
     * Always legal. The if statement is treated like a boundary
     * condition, and factored out into a loop epilogue if possible.
     * Pros: no redundant re-evaluation; does not constrain input or
     * output sizes. Cons: increases code size due to separate
     * tail-case handling. */
    Predicate,

    /** Prevent evaluation beyond the original extent by shifting
     * the tail case inwards, re-evaluating some points near the
     * end. Only legal for pure variables in pure definitions. If
     * the inner loop is very simple, the tail case is treated
     * like a boundary condition and factored out into an
     * epilogue.
     *
     * This is a good trade-off between several factors. Like
     * RoundUp, it supports vectorization well, because the inner
     * loop is always a fixed size with no data-dependent
     * branching. It increases code size slightly for inner loops
     * due to the epilogue handling, but not for outer loops
     * (e.g. loops over tiles). If used on a stage that reads from
     * an input or writes to an output, this strategy only requires
     * that the input/output extent be at least the split factor,
     * instead of a multiple of the split factor as with RoundUp. */
    ShiftInwards,

    /** For pure definitions use ShiftInwards. For pure vars in
     * update definitions use RoundUp. For RVars in update
     * definitions use GuardWithIf. */
    Auto
};
5801
/** Different ways to handle the case when the start/end of the loops of stages
 * computed with (fused) are not aligned. */
enum class LoopAlignStrategy {
    /** Shift the start of the fused loops to align. */
    AlignStart,

    /** Shift the end of the fused loops to align. */
    AlignEnd,

    /** compute_with will make no attempt to align the start/end of the
     * fused loops. */
    NoAlign,

    /** By default, LoopAlignStrategy is set to NoAlign. */
    Auto
};
5818
5819/** A reference to a site in a Halide statement at the top of the
5820 * body of a particular for loop. Evaluating a region of a halide
5821 * function is done by generating a loop nest that spans its
5822 * dimensions. We schedule the inputs to that function by
5823 * recursively injecting realizations for them at particular sites
5824 * in this loop nest. A LoopLevel identifies such a site. The site
5825 * can either be a loop nest within all stages of a function
5826 * or it can refer to a loop nest within a particular function's
5827 * stage (initial definition or updates).
5828 *
5829 * Note that a LoopLevel is essentially a pointer to an underlying value;
5830 * all copies of a LoopLevel refer to the same site, so mutating one copy
5831 * (via the set() method) will effectively mutate all copies:
5832 \code
5833 Func f;
5834 Var x;
5835 LoopLevel a(f, x);
5836 // Both a and b refer to LoopLevel(f, x)
5837 LoopLevel b = a;
5838 // Now both a and b refer to LoopLevel::root()
5839 a.set(LoopLevel::root());
5840 \endcode
5841 * This is quite useful when splitting Halide code into utility libraries, as it allows
5842 * a library to schedule code according to a caller's specifications, even if the caller
5843 * hasn't fully defined its pipeline yet:
5844 \code
5845 Func demosaic(Func input,
5846 LoopLevel intermed_compute_at,
5847 LoopLevel intermed_store_at,
5848 LoopLevel output_compute_at) {
5849 Func intermed = ...;
5850 Func output = ...;
5851 intermed.compute_at(intermed_compute_at).store_at(intermed_store_at);
5852 output.compute_at(output_compute_at);
5853 return output;
5854 }
5855
5856 void process() {
5857 // Note that these LoopLevels are all undefined when we pass them to demosaic()
5858 LoopLevel intermed_compute_at, intermed_store_at, output_compute_at;
5859 Func input = ...;
5860 Func demosaiced = demosaic(input, intermed_compute_at, intermed_store_at, output_compute_at);
5861 Func output = ...;
5862
5863 // We need to ensure all LoopLevels have a well-defined value prior to lowering:
5864 intermed_compute_at.set(LoopLevel(output, y));
5865 intermed_store_at.set(LoopLevel(output, y));
5866 output_compute_at.set(LoopLevel(output, x));
5867 }
5868 \endcode
5869 */
class LoopLevel {
    Internal::IntrusivePtr<Internal::LoopLevelContents> contents;

    // Wrap an existing LoopLevelContents object.
    explicit LoopLevel(Internal::IntrusivePtr<Internal::LoopLevelContents> c)
        : contents(std::move(c)) {
    }
    // Construct from the constituent fields. A 'locked' LoopLevel is one
    // that can no longer be mutated via set() (see lock() below).
    LoopLevel(const std::string &func_name, const std::string &var_name,
              bool is_rvar, int stage_index, bool locked = false);

public:
    /** Return the index of the function stage associated with this loop level.
     * Asserts if undefined */
    int stage_index() const;

    /** Identify the loop nest corresponding to some dimension of some function */
    // @{
    LoopLevel(const Internal::Function &f, const VarOrRVar &v, int stage_index = -1);
    LoopLevel(const Func &f, const VarOrRVar &v, int stage_index = -1);
    // @}

    /** Construct an undefined LoopLevel. Calling any method on an undefined
     * LoopLevel (other than set()) will assert. */
    LoopLevel();

    /** Construct a special LoopLevel value that implies
     * that a function should be inlined away. */
    static LoopLevel inlined();

    /** Construct a special LoopLevel value which represents the
     * location outside of all for loops. */
    static LoopLevel root();

    /** Mutate our contents to match the contents of 'other'. */
    void set(const LoopLevel &other);

    // All the public methods below this point are meant only for internal
    // use by Halide, rather than user code; hence, they are deliberately
    // documented with plain comments (rather than Doxygen) to avoid being
    // present in user documentation.

    // Lock this LoopLevel.
    LoopLevel &lock();

    // Return the Func name. Asserts if the LoopLevel is_root() or is_inlined() or !defined().
    std::string func() const;

    // Return the VarOrRVar. Asserts if the LoopLevel is_root() or is_inlined() or !defined().
    VarOrRVar var() const;

    // Return true iff the LoopLevel is defined. (Only LoopLevels created
    // with the default ctor are undefined.)
    bool defined() const;

    // Test if a loop level corresponds to inlining the function.
    bool is_inlined() const;

    // Test if a loop level is 'root', which describes the site
    // outside of all for loops.
    bool is_root() const;

    // Return a string of the form func.var -- note that this is safe
    // to call for root or inline LoopLevels, but asserts if !defined().
    std::string to_string() const;

    // Compare this loop level against the variable name of a for
    // loop, to see if this loop level refers to the site
    // immediately inside this loop. Asserts if !defined().
    bool match(const std::string &loop) const;

    // As above, but compare against another LoopLevel.
    bool match(const LoopLevel &other) const;

    // Check if two loop levels are exactly the same.
    bool operator==(const LoopLevel &other) const;

    bool operator!=(const LoopLevel &other) const {
        return !(*this == other);
    }

private:
    // Internal invariant checks; these assert on failure.
    void check_defined() const;
    void check_locked() const;
    void check_defined_and_locked() const;
};
5953
/** The loop level up to which a stage's loops are fused with another
 * stage's, plus per-dimension alignment strategies. */
struct FuseLoopLevel {
    /** The innermost fused loop level. */
    LoopLevel level;
    /** Contains alignment strategies for the fused dimensions (indexed by the
     * dimension name). If not in the map, use the default alignment strategy
     * to align the fused dimension (see \ref LoopAlignStrategy::Auto).
     */
    std::map<std::string, LoopAlignStrategy> align;

    /** Default: a locked, inlined LoopLevel (presumably meaning no fusion). */
    FuseLoopLevel()
        : level(LoopLevel::inlined().lock()) {
    }
    FuseLoopLevel(const LoopLevel &level, const std::map<std::string, LoopAlignStrategy> &align)
        : level(level), align(align) {
    }
};
5969
5970namespace Internal {
5971
5972class IRMutator;
5973struct ReductionVariable;
5974
/** A single transformation applied to the loop variables of a stage's
 * schedule: a split, a fuse, a rename, or a purification of an RVar.
 * Kept in one ordered list (see StageSchedule::splits) because the
 * order of application matters. */
struct Split {
    /** The variable being transformed, and the two variables produced.
     * Their exact meaning depends on split_type; see the comments on
     * SplitType below. */
    std::string old_var, outer, inner;
    /** The split factor. Only meaningful for a true split; should be
     * one for renames and purifies. */
    Expr factor;
    bool exact;  // Is it required that the factor divides the extent
                 // of the old var. True for splits of RVars. Forces
                 // tail strategy to be GuardWithIf.
    TailStrategy tail;

    enum SplitType { SplitVar = 0,
                     RenameVar,
                     FuseVars,
                     PurifyRVar };

    // If split_type is Rename, then this is just a renaming of the
    // old_var to the outer and not a split. The inner var should
    // be ignored, and factor should be one. Renames are kept in
    // the same list as splits so that ordering between them is
    // respected.

    // If split type is Purify, this replaces the old_var RVar to
    // the outer Var. The inner var should be ignored, and factor
    // should be one.

    // If split_type is Fuse, then this does the opposite of a
    // split, it joins the outer and inner into the old_var.
    SplitType split_type;

    // Convenience predicates for the four split types.
    bool is_rename() const {
        return split_type == RenameVar;
    }
    bool is_split() const {
        return split_type == SplitVar;
    }
    bool is_fuse() const {
        return split_type == FuseVars;
    }
    bool is_purify() const {
        return split_type == PurifyRVar;
    }
};
6015
/** Each Dim below has a dim_type, which tells you what
 * transformations are legal on it. When you combine two Dims of
 * distinct DimTypes (e.g. with Stage::fuse), the combined result has
 * the greater enum value of the two types. */
enum class DimType {
    /** This dim originated from a Var. You can evaluate a Func at
     * distinct values of this Var in any order over an interval
     * that's at least as large as the interval required. In pure
     * definitions you can even redundantly re-evaluate points. */
    PureVar = 0,

    /** The dim originated from an RVar. You can evaluate a Func at
     * distinct values of this RVar in any order (including in
     * parallel) over exactly the interval specified in the
     * RDom. PureRVars can also be reordered arbitrarily in the dims
     * list, as there are no data hazards between the evaluation of
     * the Func at distinct values of the RVar.
     *
     * The most common case where an RVar is considered pure is RVars
     * that are used in a way which obeys all the syntactic
     * constraints that a Var does, e.g.:
     *
     \code
     RDom r(0, 100);
     f(r.x) = f(r.x) + 5;
     \endcode
     *
     * Other cases where RVars are pure are where the sites being
     * written to by the Func evaluated at one value of the RVar
     * couldn't possibly collide with the sites being written or read
     * by the Func at a distinct value of the RVar. For example, r.x
     * is pure in the following three definitions:
     *
     \code

     // This definition writes to even coordinates and reads from the
     // same site (which no other value of r.x is writing to) and odd
     // sites (which no other value of r.x is writing to):
     f(2*r.x) = max(f(2*r.x), f(2*r.x + 7));

     // This definition writes to scanline zero and reads from the
     // same site and scanline one:
     f(r.x, 0) += f(r.x, 1);

     // This definition reads and writes over non-overlapping ranges:
     f(r.x + 100) += f(r.x);
     \endcode
     *
     * To give two counterexamples, r.x is not pure in the following
     * definitions:
     *
     \code
     // The same site is written by distinct values of the RVar
     // (write-after-write hazard):
     f(r.x / 2) += f(r.x);

     // One value of r.x reads from a site that another value of r.x
     // is writing to (read-after-write hazard):
     f(r.x) += f(r.x + 1);
     \endcode
     */
    PureRVar,

    /** The dim originated from an RVar. You must evaluate a Func at
     * distinct values of this RVar in increasing order over precisely
     * the interval specified in the RDom. ImpureRVars may not be
     * reordered with respect to other ImpureRVars.
     *
     * All RVars are impure by default. Those for which we can prove
     * no data hazards exist get promoted to PureRVar. There are two
     * instances in which ImpureRVars may be parallelized or reordered
     * even in the presence of hazards:
     *
     * 1) In the case of an update definition that has been proven to be
     * an associative and commutative reduction, reordering of
     * ImpureRVars is allowed, and parallelizing them is allowed if
     * the update has been made atomic.
     *
     * 2) ImpureRVars can also be reordered and parallelized if
     * Func::allow_race_conditions() has been set. This is the escape
     * hatch for when there are no hazards but the checks above failed
     * to prove that (RDom::where can encode arbitrary facts about
     * non-linear integer arithmetic, which is undecidable), or for
     * when you don't actually care about the non-determinism
     * introduced by data hazards (e.g. in the algorithm HOGWILD!).
     */
    ImpureRVar,
};
6104
6105/** The Dim struct represents one loop in the schedule's
6106 * representation of a loop nest. */
6107struct Dim {
6108 /** Name of the loop variable */
6109 std::string var;
6110
6111 /** How are the loop values traversed (e.g. unrolled, vectorized, parallel) */
6112 ForType for_type;
6113
6114 /** On what device does the body of the loop execute (e.g. Host, GPU, Hexagon) */
6115 DeviceAPI device_api;
6116
6117 /** The DimType tells us what transformations are legal on this
6118 * loop (see the DimType enum above). */
6119 DimType dim_type;
6120
6121 /** Can this loop be evaluated in any order (including in
6122 * parallel)? Equivalently, are there no data hazards between
6123 * evaluations of the Func at distinct values of this var? */
6124 bool is_pure() const {
6125 return (dim_type == DimType::PureVar) || (dim_type == DimType::PureRVar);
6126 }
6127
6128 /** Did this loop originate from an RVar (in which case the bounds
6129 * of the loops are algorithmically meaningful)? */
6130 bool is_rvar() const {
6131 return (dim_type == DimType::PureRVar) || (dim_type == DimType::ImpureRVar);
6132 }
6133
6134 /** Could multiple iterations of this loop happen at the same
6135 * time, with reads and writes interleaved in arbitrary ways
6136 * according to the memory model of the underlying compiler and
6137 * machine? */
6138 bool is_unordered_parallel() const {
6139 return Halide::Internal::is_unordered_parallel(for_type);
6140 }
6141
6142 /** Could multiple iterations of this loop happen at the same
6143 * time? Vectorized and GPULanes loop types are parallel but not
6144 * unordered, because the loop iterations proceed together in
6145 * lockstep with some well-defined outcome if there are hazards. */
6146 bool is_parallel() const {
6147 return Halide::Internal::is_parallel(for_type);
6148 }
6149};
6150
/** A bound on a loop, typically from Func::bound. Also used to carry
 * per-dimension estimates (see FuncSchedule::estimates below). */
struct Bound {
    /** The loop var being bounded */
    std::string var;

    /** Declared min and extent of the loop. min may be undefined if
     * Func::bound_extent was used. */
    Expr min, extent;

    /** If defined, the number of iterations will be a multiple of
     * "modulus", and the first iteration will be at a value congruent
     * to "remainder" modulo "modulus". Set by Func::align_bounds and
     * Func::align_extent. */
    Expr modulus, remainder;
};
6166
/** Properties of one axis of the storage of a Func */
struct StorageDim {
    /** The var in the pure definition corresponding to this axis */
    std::string var;

    /** The bounds allocated (not computed) must be a multiple of
     * "alignment". Set by Func::align_storage. */
    Expr alignment;

    /** If the Func is explicitly folded along this axis (with
     * Func::fold_storage) this gives the extent of the circular
     * buffer used, and whether it is used in increasing order
     * (fold_forward = true) or decreasing order (fold_forward =
     * false). fold_forward is presumably only meaningful when
     * fold_factor is defined. */
    Expr fold_factor;
    bool fold_forward;
};
6184
/** This represents two stages with fused loop nests from outermost to
 * a specific loop level. The loops to compute func_1(stage_1) are
 * fused with the loops to compute func_2(stage_2) from outermost to
 * loop level var_name and the computation from stage_1 of func_1
 * occurs first.
 */
struct FusedPair {
    std::string func_1;
    std::string func_2;
    size_t stage_1;
    size_t stage_2;
    std::string var_name;

    FusedPair() = default;
    FusedPair(const std::string &f1, size_t s1, const std::string &f2,
              size_t s2, const std::string &var)
        : func_1(f1), func_2(f2), stage_1(s1), stage_2(s2), var_name(var) {
    }

    /** Two FusedPairs are equal iff every field matches. */
    bool operator==(const FusedPair &other) const {
        return func_1 == other.func_1 &&
               func_2 == other.func_2 &&
               stage_1 == other.stage_1 &&
               stage_2 == other.stage_2 &&
               var_name == other.var_name;
    }

    /** Lexicographic ordering on (func_1, func_2, var_name, stage_1,
     * stage_2), suitable for use in ordered containers. */
    bool operator<(const FusedPair &other) const {
        if (func_1 < other.func_1) {
            return true;
        }
        if (other.func_1 < func_1) {
            return false;
        }
        if (func_2 < other.func_2) {
            return true;
        }
        if (other.func_2 < func_2) {
            return false;
        }
        if (var_name < other.var_name) {
            return true;
        }
        if (other.var_name < var_name) {
            return false;
        }
        if (stage_1 < other.stage_1) {
            return true;
        }
        if (other.stage_1 < stage_1) {
            return false;
        }
        return stage_2 < other.stage_2;
    }
};
6225
6226struct FuncScheduleContents;
6227struct StageScheduleContents;
6228struct FunctionContents;
6229
/** A schedule for a Function of a Halide pipeline. This schedule is
 * applied to all stages of the Function. Right now this interface is
 * basically a struct, offering mutable access to its innards.
 * In the future it may become more encapsulated. */
class FuncSchedule {
    IntrusivePtr<FuncScheduleContents> contents;

public:
    // Wrap an existing contents object. (Not explicit; allows implicit
    // construction from an IntrusivePtr, matching StageSchedule below.)
    FuncSchedule(IntrusivePtr<FuncScheduleContents> c)
        : contents(std::move(c)) {
    }
    FuncSchedule(const FuncSchedule &other) = default;
    // Construct a schedule with default contents (defined out-of-line).
    FuncSchedule();

    /** Return a deep copy of this FuncSchedule. It recursively deep copies all
     * called functions, schedules, specializations, and reduction domains. This
     * method takes a map of <old FunctionContents, deep-copied version> as input
     * and would use the deep-copied FunctionContents from the map if exists
     * instead of creating a new deep-copy to avoid creating deep-copies of the
     * same FunctionContents multiple times.
     */
    FuncSchedule deep_copy(
        std::map<FunctionPtr, FunctionPtr> &copied_map) const;

    /** This flag is set to true if the schedule is memoized. */
    // @{
    bool &memoized();
    bool memoized() const;
    // @}

    /** This flag is set to true if the schedule is memoized and has an attached
     * eviction key. */
    // @{
    Expr &memoize_eviction_key();
    Expr memoize_eviction_key() const;
    // @}

    /** Is the production of this Function done asynchronously */
    bool &async();
    bool async() const;

    /** The list and order of dimensions used to store this
     * function. The first dimension in the vector corresponds to the
     * innermost dimension for storage (i.e. which dimension is
     * tightly packed in memory) */
    // @{
    const std::vector<StorageDim> &storage_dims() const;
    std::vector<StorageDim> &storage_dims();
    // @}

    /** The memory type (heap/stack/shared/etc) used to back this Func. */
    // @{
    MemoryType memory_type() const;
    MemoryType &memory_type();
    // @}

    /** You may explicitly bound some of the dimensions of a function,
     * or constrain them to lie on multiples of a given factor. See
     * \ref Func::bound and \ref Func::align_bounds and \ref Func::align_extent. */
    // @{
    const std::vector<Bound> &bounds() const;
    std::vector<Bound> &bounds();
    // @}

    /** You may explicitly specify an estimate of some of the function
     * dimensions. See \ref Func::set_estimate */
    // @{
    const std::vector<Bound> &estimates() const;
    std::vector<Bound> &estimates();
    // @}

    /** Mark calls of a function by 'f' to be replaced with its identity
     * wrapper or clone during the lowering stage. If the string 'f' is empty,
     * it means replace all calls to the function by all other functions
     * (excluding itself) in the pipeline with the global identity wrapper.
     * See \ref Func::in and \ref Func::clone_in for more details. */
    // @{
    const std::map<std::string, Internal::FunctionPtr> &wrappers() const;
    std::map<std::string, Internal::FunctionPtr> &wrappers();
    void add_wrapper(const std::string &f,
                     const Internal::FunctionPtr &wrapper);
    // @}

    /** At what sites should we inject the allocation and the
     * computation of this function? The store_level must be outside
     * of or equal to the compute_level. If the compute_level is
     * inline, the store_level is meaningless. See \ref Func::store_at
     * and \ref Func::compute_at */
    // @{
    const LoopLevel &store_level() const;
    const LoopLevel &compute_level() const;
    LoopLevel &store_level();
    LoopLevel &compute_level();
    // @}

    /** Pass an IRVisitor through to all Exprs referenced in the
     * Schedule. */
    void accept(IRVisitor *) const;

    /** Pass an IRMutator through to all Exprs referenced in the
     * Schedule. */
    void mutate(IRMutator *);
};
6333
/** A schedule for a single stage of a Halide pipeline. Right now this
 * interface is basically a struct, offering mutable access to its
 * innards. In the future it may become more encapsulated. */
class StageSchedule {
    IntrusivePtr<StageScheduleContents> contents;

public:
    // Wrap an existing contents object. (Not explicit; allows implicit
    // construction from an IntrusivePtr, matching FuncSchedule above.)
    StageSchedule(IntrusivePtr<StageScheduleContents> c)
        : contents(std::move(c)) {
    }
    StageSchedule(const StageSchedule &other) = default;
    // Construct a schedule with default contents (defined out-of-line).
    StageSchedule();

    /** Return a copy of this StageSchedule. */
    StageSchedule get_copy() const;

    /** This flag is set to true if the dims list has been manipulated
     * by the user (or if a ScheduleHandle was created that could have
     * been used to manipulate it). It controls the warning that
     * occurs if you schedule the vars of the pure step but not the
     * update steps. */
    // @{
    bool &touched();
    bool touched() const;
    // @}

    /** RVars of reduction domain associated with this schedule if there is any. */
    // @{
    const std::vector<ReductionVariable> &rvars() const;
    std::vector<ReductionVariable> &rvars();
    // @}

    /** The traversal of the domain of a function can have some of its
     * dimensions split into sub-dimensions. See \ref Func::split */
    // @{
    const std::vector<Split> &splits() const;
    std::vector<Split> &splits();
    // @}

    /** The list and ordering of dimensions used to evaluate this
     * function, after all splits have taken place. The first
     * dimension in the vector corresponds to the innermost for loop,
     * and the last is the outermost. Also specifies what type of for
     * loop to use for each dimension. Does not specify the bounds on
     * each dimension. These get inferred from how the function is
     * used, what the splits are, and any optional bounds in the list below. */
    // @{
    const std::vector<Dim> &dims() const;
    std::vector<Dim> &dims();
    // @}

    /** You may perform prefetching in some of the dimensions of a
     * function. See \ref Func::prefetch */
    // @{
    const std::vector<PrefetchDirective> &prefetches() const;
    std::vector<PrefetchDirective> &prefetches();
    // @}

    /** Innermost loop level of fused loop nest for this function stage.
     * Fusion runs from outermost to this loop level. The stages being fused
     * should not have producer/consumer relationship. See \ref Func::compute_with */
    // @{
    const FuseLoopLevel &fuse_level() const;
    FuseLoopLevel &fuse_level();
    // @}

    /** List of function stages that are to be fused with this function stage
     * from the outermost loop to a certain loop level. Those function stages
     * are to be computed AFTER this function stage at the last fused loop level.
     * This list is populated when realization_order() is called. See
     * \ref Func::compute_with */
    // @{
    const std::vector<FusedPair> &fused_pairs() const;
    std::vector<FusedPair> &fused_pairs();
    // @}

    /** Are race conditions permitted? */
    // @{
    bool allow_race_conditions() const;
    bool &allow_race_conditions();
    // @}

    /** Use atomic update? */
    // @{
    bool atomic() const;
    bool &atomic();
    // @}

    /** Atomic updates are only allowed on associative reductions.
     * We try to prove the associativity, but the user can override
     * the associativity test and suppress compiler error if the prover
     * fails to recognize the associativity or the user does not care. */
    // @{
    bool override_atomic_associativity_test() const;
    bool &override_atomic_associativity_test();
    // @}

    /** Pass an IRVisitor through to all Exprs referenced in the
     * Schedule. */
    void accept(IRVisitor *) const;

    /** Pass an IRMutator through to all Exprs referenced in the
     * Schedule. */
    void mutate(IRMutator *);
};
6439
6440} // namespace Internal
6441} // namespace Halide
6442
6443#endif
6444
6445namespace Halide {
6446namespace Internal {
6447
6448struct ApplySplitResult {
6449 // If type is "Substitution", then this represents a substitution of
6450 // variable "name" to value. If type is "LetStmt", we should insert a new
6451 // let stmt defining "name" with value "value". If type is "Predicate", we
6452 // should ignore "name" and the predicate is "value".
6453
6454 std::string name;
6455 Expr value;
6456
6457 enum Type { Substitution = 0,
6458 LetStmt,
6459 Predicate };
6460 Type type;
6461
6462 ApplySplitResult(const std::string &n, Expr val, Type t)
6463 : name(n), value(std::move(val)), type(t) {
6464 }
6465 ApplySplitResult(Expr val)
6466 : name(""), value(std::move(val)), type(Predicate) {
6467 }
6468
6469 bool is_substitution() const {
6470 return (type == Substitution);
6471 }
6472 bool is_let() const {
6473 return (type == LetStmt);
6474 }
6475 bool is_predicate() const {
6476 return (type == Predicate);
6477 }
6478};
6479
/** Given a Split schedule on a definition (init or update), return a list
 * of predicates on the definition, substitutions that need to be applied to
 * the definition (in ascending order of application), and let stmts which
 * define the values of variables referred to by the predicates and
 * substitutions (ordered from innermost to outermost let). */
std::vector<ApplySplitResult> apply_split(
    const Split &split, bool is_update, const std::string &prefix,
    std::map<std::string, Expr> &dim_extent_alignment);

/** Compute the loop bounds of the new dimensions resulting from applying the
 * split schedules using the loop bounds of the old dimensions. */
std::vector<std::pair<std::string, Expr>> compute_loop_bounds_after_split(
    const Split &split, const std::string &prefix);
6493
6494} // namespace Internal
6495} // namespace Halide
6496
6497#endif
6498#ifndef HALIDE_ARGUMENT_H
6499#define HALIDE_ARGUMENT_H
6500
6501/** \file
6502 * Defines a type used for expressing the type signature of a
6503 * generated halide pipeline
6504 */
6505
6506
6507namespace Halide {
6508
6509template<typename T>
6510class Buffer;
6511
/** Estimates and default/min/max values attached to an Argument. */
struct ArgumentEstimates {
    /** If this is a scalar argument, then these are its default, min, max, and estimated values.
     * For buffer arguments, all should be undefined. */
    Expr scalar_def, scalar_min, scalar_max, scalar_estimate;

    /** If this is a buffer argument, these are the estimated min and
     * extent for each dimension. If there are no estimates,
     * buffer_estimates.size() can be zero; otherwise, it must always
     * equal the dimensions */
    Region buffer_estimates;

    /** Equality comparison (defined out-of-line). */
    bool operator==(const ArgumentEstimates &rhs) const;
};
6525
6526/**
6527 * A struct representing an argument to a halide-generated
6528 * function. Used for specifying the function signature of
6529 * generated code.
6530 */
6531struct Argument {
6532 /** The name of the argument */
6533 std::string name;
6534
6535 /** An argument is either a primitive type (for parameters), or a
6536 * buffer pointer.
6537 *
6538 * If kind == InputScalar, then type fully encodes the expected type
6539 * of the scalar argument.
6540 *
6541 * If kind == InputBuffer|OutputBuffer, then type.bytes() should be used
6542 * to determine* elem_size of the buffer; additionally, type.code *should*
6543 * reflect the expected interpretation of the buffer data (e.g. float vs int),
6544 * but there is no runtime enforcement of this at present.
6545 */
6546 enum Kind {
6547 InputScalar = halide_argument_kind_input_scalar,
6548 InputBuffer = halide_argument_kind_input_buffer,
6549 OutputBuffer = halide_argument_kind_output_buffer
6550 };
6551 Kind kind = InputScalar;
6552
6553 /** If kind == InputBuffer|OutputBuffer, this is the dimensionality of the buffer.
6554 * If kind == InputScalar, this value is ignored (and should always be set to zero) */
6555 uint8_t dimensions = 0;
6556
6557 /** If this is a scalar parameter, then this is its type.
6558 *
6559 * If this is a buffer parameter, this this is the type of its
6560 * elements.
6561 *
6562 * Note that type.lanes should always be 1 here. */
6563 Type type;
6564
6565 /* The estimates (if any) and default/min/max values (if any) for this Argument. */
6566 ArgumentEstimates argument_estimates;
6567
6568 Argument() = default;
6569 Argument(const std::string &_name, Kind _kind, const Type &_type, int _dimensions,
6570 const ArgumentEstimates &argument_estimates);
6571
6572 // Not explicit, so that you can put Buffer in an argument list,
6573 // to indicate that it shouldn't be baked into the object file,
6574 // but instead received as an argument at runtime
6575 template<typename T>
6576 Argument(Buffer<T> im)
6577 : name(im.name()),
6578 kind(InputBuffer),
6579 dimensions(im.dimensions()),
6580 type(im.type()) {
6581 }
6582
6583 bool is_buffer() const {
6584 return kind == InputBuffer || kind == OutputBuffer;
6585 }
6586 bool is_scalar() const {
6587 return kind == InputScalar;
6588 }
6589
6590 bool is_input() const {
6591 return kind == InputScalar || kind == InputBuffer;
6592 }
6593 bool is_output() const {
6594 return kind == OutputBuffer;
6595 }
6596
6597 bool operator==(const Argument &rhs) const {
6598 return name == rhs.name &&
6599 kind == rhs.kind &&
6600 dimensions == rhs.dimensions &&
6601 type == rhs.type &&
6602 argument_estimates == rhs.argument_estimates;
6603 }
6604};
6605
6606} // namespace Halide
6607
6608#endif
6609#ifndef HALIDE_ASSOCIATIVE_OPS_TABLE_H
6610#define HALIDE_ASSOCIATIVE_OPS_TABLE_H
6611
6612/** \file
6613 * Tables listing associative operators and their identities.
6614 */
6615
6616#ifndef HALIDE_IR_EQUALITY_H
6617#define HALIDE_IR_EQUALITY_H
6618
6619/** \file
6620 * Methods to test Exprs and Stmts for equality of value
6621 */
6622
6623
6624namespace Halide {
6625namespace Internal {
6626
6627/** A compare struct suitable for use in std::map and std::set that
6628 * computes a lexical ordering on IR nodes. */
6629struct IRDeepCompare {
6630 bool operator()(const Expr &a, const Expr &b) const;
6631 bool operator()(const Stmt &a, const Stmt &b) const;
6632};
6633
6634/** Lossily track known equal exprs with a cache. On collision, the
6635 * old pair is evicted. Used below by ExprWithCompareCache. */
6636class IRCompareCache {
6637private:
6638 struct Entry {
6639 Expr a, b;
6640 };
6641
6642 int bits;
6643
6644 uint32_t hash(const Expr &a, const Expr &b) const {
6645 // Note this hash is symmetric in a and b, so that a
6646 // comparison in a and b hashes to the same bucket as
6647 // a comparison on b and a.
6648 uint64_t pa = (uint64_t)(a.get());
6649 uint64_t pb = (uint64_t)(b.get());
6650 uint64_t mix = (pa + pb) + (pa ^ pb);
6651 mix ^= (mix >> bits);
6652 mix ^= (mix >> (bits * 2));
6653 uint32_t bottom = mix & ((1 << bits) - 1);
6654 return bottom;
6655 }
6656
6657 std::vector<Entry> entries;
6658
6659public:
6660 void insert(const Expr &a, const Expr &b) {
6661 uint32_t h = hash(a, b);
6662 entries[h].a = a;
6663 entries[h].b = b;
6664 }
6665
6666 bool contains(const Expr &a, const Expr &b) const {
6667 uint32_t h = hash(a, b);
6668 const Entry &e = entries[h];
6669 return ((a.same_as(e.a) && b.same_as(e.b)) ||
6670 (a.same_as(e.b) && b.same_as(e.a)));
6671 }
6672
6673 void clear() {
6674 for (size_t i = 0; i < entries.size(); i++) {
6675 entries[i].a = Expr();
6676 entries[i].b = Expr();
6677 }
6678 }
6679
6680 IRCompareCache() = default;
6681 IRCompareCache(int b)
6682 : bits(b), entries(static_cast<size_t>(1) << bits) {
6683 }
6684};
6685
/** A wrapper around Exprs so that they can be deeply compared with a
6687 * cache for known-equal subexpressions. Useful for unsanitized Exprs
6688 * coming in from the front-end, which may be horrible graphs with
6689 * sub-expressions that are equal by value but not by identity. This
6690 * isn't a comparison object like IRDeepCompare above, because libc++
6691 * requires that comparison objects be stateless (and constructs a new
6692 * one for each comparison!), so they can't have a cache associated
6693 * with them. However, by sneakily making the cache a mutable member
6694 * of the objects being compared, we can dodge this issue.
6695 *
6696 * Clunky example usage:
6697 *
6698\code
6699Expr a, b, c, query;
6700std::set<ExprWithCompareCache> s;
6701IRCompareCache cache(8);
6702s.insert(ExprWithCompareCache(a, &cache));
6703s.insert(ExprWithCompareCache(b, &cache));
6704s.insert(ExprWithCompareCache(c, &cache));
if (s.count(ExprWithCompareCache(query, &cache))) {...}
6706\endcode
6707 *
6708 */
struct ExprWithCompareCache {
    /** The wrapped expression. */
    Expr expr;
    /** Shared cache of known-equal pairs. A mutable raw pointer
     * because comparison objects must be stateless (see the comment
     * above); the cache is owned by the caller, not this struct. */
    mutable IRCompareCache *cache = nullptr;

    ExprWithCompareCache() = default;
    ExprWithCompareCache(const Expr &e, IRCompareCache *c)
        : expr(e), cache(c) {
    }

    /** The comparison uses (and updates) the cache */
    bool operator<(const ExprWithCompareCache &other) const;
};
6721
/** Compare IR nodes for equality of value. Traverses entire IR
 * tree. For equality of reference, use Expr::same_as. If you're
 * comparing non-CSE'd Exprs, use graph_equal, which is safe for nasty
 * graphs of IR nodes. */
// @{
bool equal(const Expr &a, const Expr &b);
bool equal(const Stmt &a, const Stmt &b);
bool graph_equal(const Expr &a, const Expr &b);
bool graph_equal(const Stmt &a, const Stmt &b);
// @}

/** Internal self-test for the IR equality machinery. */
void ir_equality_test();
6734
6735} // namespace Internal
6736} // namespace Halide
6737
6738#endif
6739#ifndef HALIDE_IR_OPERATOR_H
6740#define HALIDE_IR_OPERATOR_H
6741
6742/** \file
6743 *
6744 * Defines various operator overloads and utility functions that make
6745 * it more pleasant to work with Halide expressions.
6746 */
6747
6748#include <cmath>
6749
6750#ifndef HALIDE_TUPLE_H
6751#define HALIDE_TUPLE_H
6752
6753/** \file
6754 *
6755 * Defines Tuple - the front-end handle on small arrays of expressions.
6756 */
6757#include <vector>
6758
6759
6760namespace Halide {
6761
6762class FuncRef;
6763
6764/** Create a small array of Exprs for defining and calling functions
6765 * with multiple outputs. */
6766class Tuple {
6767private:
6768 std::vector<Expr> exprs;
6769
6770public:
6771 /** The number of elements in the tuple. */
6772 size_t size() const {
6773 return exprs.size();
6774 }
6775
6776 /** Get a reference to an element. */
6777 Expr &operator[](size_t x) {
6778 user_assert(x < exprs.size()) << "Tuple access out of bounds\n";
6779 return exprs[x];
6780 }
6781
6782 /** Get a copy of an element. */
6783 Expr operator[](size_t x) const {
6784 user_assert(x < exprs.size()) << "Tuple access out of bounds\n";
6785 return exprs[x];
6786 }
6787
6788 /** Construct a Tuple of a single Expr */
6789 explicit Tuple(Expr e) {
6790 exprs.emplace_back(std::move(e));
6791 }
6792
6793 /** Construct a Tuple from some Exprs. */
6794 //@{
6795 template<typename... Args>
6796 Tuple(const Expr &a, const Expr &b, Args &&...args) {
6797 exprs = std::vector<Expr>{a, b, std::forward<Args>(args)...};
6798 }
6799 //@}
6800
6801 /** Construct a Tuple from a vector of Exprs */
6802 explicit HALIDE_NO_USER_CODE_INLINE Tuple(const std::vector<Expr> &e)
6803 : exprs(e) {
6804 user_assert(!e.empty()) << "Tuples must have at least one element\n";
6805 }
6806
6807 /** Construct a Tuple from a function reference. */
6808 Tuple(const FuncRef &);
6809
6810 /** Treat the tuple as a vector of Exprs */
6811 const std::vector<Expr> &as_vector() const {
6812 return exprs;
6813 }
6814};
6815
6816} // namespace Halide
6817
6818#endif
6819
6820namespace Halide {
6821
6822namespace Internal {
6823/** Is the expression either an IntImm, a FloatImm, a StringImm, or a
6824 * Cast of the same, or a Ramp or Broadcast of the same. Doesn't do
6825 * any constant folding. */
6826bool is_const(const Expr &e);
6827
6828/** Is the expression an IntImm, FloatImm of a particular value, or a
6829 * Cast, or Broadcast of the same. */
6830bool is_const(const Expr &e, int64_t v);
6831
6832/** If an expression is an IntImm or a Broadcast of an IntImm, return
6833 * a pointer to its value. Otherwise returns nullptr. */
6834const int64_t *as_const_int(const Expr &e);
6835
6836/** If an expression is a UIntImm or a Broadcast of a UIntImm, return
6837 * a pointer to its value. Otherwise returns nullptr. */
6838const uint64_t *as_const_uint(const Expr &e);
6839
6840/** If an expression is a FloatImm or a Broadcast of a FloatImm,
6841 * return a pointer to its value. Otherwise returns nullptr. */
6842const double *as_const_float(const Expr &e);
6843
6844/** Is the expression a constant integer power of two. Also returns
6845 * log base two of the expression if it is. Only returns true for
6846 * integer types. */
6847bool is_const_power_of_two_integer(const Expr &e, int *bits);
6848
6849/** Is the expression a const (as defined by is_const), and also
6850 * strictly greater than zero (in all lanes, if a vector expression) */
6851bool is_positive_const(const Expr &e);
6852
6853/** Is the expression a const (as defined by is_const), and also
6854 * strictly less than zero (in all lanes, if a vector expression) */
6855bool is_negative_const(const Expr &e);
6856
6857/** Is the expression an undef */
6858bool is_undef(const Expr &e);
6859
6860/** Is the expression a const (as defined by is_const), and also equal
6861 * to zero (in all lanes, if a vector expression) */
6862bool is_const_zero(const Expr &e);
6863
6864/** Is the expression a const (as defined by is_const), and also equal
6865 * to one (in all lanes, if a vector expression) */
6866bool is_const_one(const Expr &e);
6867
6868/** Is the statement a no-op (which we represent as either an
6869 * undefined Stmt, or as an Evaluate node of a constant) */
6870bool is_no_op(const Stmt &s);
6871
6872/** Does the expression
6873 * 1) Take on the same value no matter where it appears in a Stmt, and
6874 * 2) Evaluating it has no side-effects
6875 */
6876bool is_pure(const Expr &e);
6877
6878/** Construct an immediate of the given type from any numeric C++ type. */
6879// @{
6880Expr make_const(Type t, int64_t val);
6881Expr make_const(Type t, uint64_t val);
6882Expr make_const(Type t, double val);
6883inline Expr make_const(Type t, int32_t val) {
6884 return make_const(t, (int64_t)val);
6885}
6886inline Expr make_const(Type t, uint32_t val) {
6887 return make_const(t, (uint64_t)val);
6888}
6889inline Expr make_const(Type t, int16_t val) {
6890 return make_const(t, (int64_t)val);
6891}
6892inline Expr make_const(Type t, uint16_t val) {
6893 return make_const(t, (uint64_t)val);
6894}
6895inline Expr make_const(Type t, int8_t val) {
6896 return make_const(t, (int64_t)val);
6897}
6898inline Expr make_const(Type t, uint8_t val) {
6899 return make_const(t, (uint64_t)val);
6900}
6901inline Expr make_const(Type t, bool val) {
6902 return make_const(t, (uint64_t)val);
6903}
6904inline Expr make_const(Type t, float val) {
6905 return make_const(t, (double)val);
6906}
6907inline Expr make_const(Type t, float16_t val) {
6908 return make_const(t, (double)val);
6909}
6910// @}
6911
6912/** Construct a unique signed_integer_overflow Expr */
6913Expr make_signed_integer_overflow(Type type);
6914
6915/** Check if a constant value can be correctly represented as the given type. */
6916void check_representable(Type t, int64_t val);
6917
6918/** Construct a boolean constant from a C++ boolean value.
6919 * May also be a vector if width is given.
6920 * It is not possible to coerce a C++ boolean to Expr because
6921 * if we provide such a path then char objects can ambiguously
6922 * be converted to Halide Expr or to std::string. The problem
6923 * is that C++ does not have a real bool type - it is in fact
6924 * close enough to char that C++ does not know how to distinguish them.
6925 * make_bool is the explicit coercion. */
6926Expr make_bool(bool val, int lanes = 1);
6927
6928/** Construct the representation of zero in the given type */
6929Expr make_zero(Type t);
6930
6931/** Construct the representation of one in the given type */
6932Expr make_one(Type t);
6933
6934/** Construct the representation of two in the given type */
6935Expr make_two(Type t);
6936
6937/** Construct the constant boolean true. May also be a vector of
6938 * trues, if a lanes argument is given. */
6939Expr const_true(int lanes = 1);
6940
6941/** Construct the constant boolean false. May also be a vector of
6942 * falses, if a lanes argument is given. */
6943Expr const_false(int lanes = 1);
6944
6945/** Attempt to cast an expression to a smaller type while provably not
6946 * losing information. If it can't be done, return an undefined
6947 * Expr. */
6948Expr lossless_cast(Type t, Expr e);
6949
6950/** Attempt to negate x without introducing new IR and without overflow.
6951 * If it can't be done, return an undefined Expr. */
6952Expr lossless_negate(const Expr &x);
6953
6954/** Coerce the two expressions to have the same type, using C-style
6955 * casting rules. For the purposes of casting, a boolean type is
6956 * UInt(1). We use the following procedure:
6957 *
6958 * If the types already match, do nothing.
6959 *
6960 * Then, if one type is a vector and the other is a scalar, the scalar
6961 * is broadcast to match the vector width, and we continue.
6962 *
6963 * Then, if one type is floating-point and the other is not, the
6964 * non-float is cast to the floating-point type, and we're done.
6965 *
6966 * Then, if both types are unsigned ints, the one with fewer bits is
6967 * cast to match the one with more bits and we're done.
6968 *
6969 * Then, if both types are signed ints, the one with fewer bits is
6970 * cast to match the one with more bits and we're done.
6971 *
6972 * Finally, if one type is an unsigned int and the other type is a signed
6973 * int, both are cast to a signed int with the greater of the two
6974 * bit-widths. For example, matching an Int(8) with a UInt(16) results
6975 * in an Int(16).
6976 *
6977 */
6978void match_types(Expr &a, Expr &b);
6979
6980/** Asserts that both expressions are integer types and are either
6981 * both signed or both unsigned. If one argument is scalar and the
6982 * other a vector, the scalar is broadcasted to have the same number
6983 * of lanes as the vector. If one expression is of narrower type than
6984 * the other, it is widened to the bit width of the wider. */
6985void match_types_bitwise(Expr &a, Expr &b, const char *op_name);
6986
6987/** Halide's vectorizable transcendentals. */
6988// @{
6989Expr halide_log(const Expr &a);
6990Expr halide_exp(const Expr &a);
6991Expr halide_erf(const Expr &a);
6992// @}
6993
6994/** Raise an expression to an integer power by repeatedly multiplying
6995 * it by itself. */
6996Expr raise_to_integer_power(Expr a, int64_t b);
6997
6998/** Split a boolean condition into vector of ANDs. If 'cond' is undefined,
6999 * return an empty vector. */
7000void split_into_ands(const Expr &cond, std::vector<Expr> &result);
7001
7002/** A builder to help create Exprs representing halide_buffer_t
7003 * structs (e.g. foo.buffer) via calls to halide_buffer_init. Fill out
7004 * the fields and then call build. The resulting Expr will be a call
7005 * to halide_buffer_init with the struct members as arguments. If the
7006 * buffer_memory field is undefined, it uses a call to alloca to make
7007 * some stack memory for the buffer. If the shape_memory field is
7008 * undefined, it similarly uses stack memory for the shape. If the
7009 * shape_memory field is null, it uses the dim field already in the
 * buffer. Other uninitialized fields will take on a value of zero in
7011 * the constructed buffer. */
struct BufferBuilder {
    // Memory for the halide_buffer_t itself, and for its shape array.
    // See the comment above for the undefined/null conventions.
    Expr buffer_memory, shape_memory;
    // Field values for the buffer: host pointer, device handle, and
    // device interface pointer.
    Expr host, device, device_interface;
    // The element type of the buffer.
    Type type;
    // The number of dimensions of the buffer.
    int dimensions = 0;
    // Per-dimension shape fields; presumably one entry per dimension
    // (TODO confirm against the definition of build()).
    std::vector<Expr> mins, extents, strides;
    // Initial values for the dirty flags.
    Expr host_dirty, device_dirty;
    // Construct the call to halide_buffer_init described above.
    Expr build() const;
};
7021
/** If e is a ramp expression with the given stride (default 1),
 * return its base, otherwise return an undefined Expr. */
7024Expr strided_ramp_base(const Expr &e, int stride = 1);
7025
7026/** Implementations of division and mod that are specific to Halide.
7027 * Use these implementations; do not use native C division or mod to
 * simplify Halide expressions. Halide division and modulo satisfy
7029 * the Euclidean definition of division for integers a and b:
7030 *
7031 /code
7032 when b != 0, (a/b)*b + a%b = a
7033 0 <= a%b < |b|
7034 /endcode
7035 *
7036 * Additionally, mod by zero returns zero, and div by zero returns
7037 * zero. This makes mod and div total functions.
7038 */
7039// @{
template<typename T>
inline T mod_imp(T a, T b) {
    Type t = type_of<T>();
    if (!t.is_float() && b == 0) {
        // Mod by zero is defined to be zero for non-float types, so
        // that mod is a total function (see the comment above).
        return 0;
    } else if (t.is_int()) {
        // Branch-free Euclidean mod for signed integers, computed in
        // 64 bits. The result lies in [0, |b|) regardless of signs.
        int64_t ia = a;
        int64_t ib = b;
        // All-ones when the corresponding value is negative, else zero.
        int64_t a_neg = ia >> 63;
        int64_t b_neg = ib >> 63;
        // All-ones when b is zero.
        int64_t b_zero = (ib == 0) ? -1 : 0;
        // Bias a negative numerator one step towards zero, so C's
        // truncating % lands within one adjustment of the answer.
        ia -= a_neg;
        // (ib | b_zero) is ib, or -1 when ib == 0, avoiding a
        // division by zero in the % below.
        int64_t r = ia % (ib | b_zero);
        // When a was negative, add |b| - 1 (that is what
        // (ib ^ b_neg) + ~b_neg evaluates to for either sign of b) to
        // shift the truncated remainder into the Euclidean range.
        r += (a_neg & ((ib ^ b_neg) + ~b_neg));
        // Force the result to zero when b == 0.
        r &= ~b_zero;
        return r;
    } else {
        // Unsigned integers: C's % already matches Euclidean mod.
        // (Float types never reach here; see the specializations below.)
        return a % b;
    }
}
7060
template<typename T>
inline T div_imp(T a, T b) {
    Type t = type_of<T>();
    if (!t.is_float() && b == 0) {
        // Division by zero is defined to be zero for non-float types.
        return (T)0;
    } else if (t.is_int()) {
        // Branch-free Euclidean division for signed integers: the
        // quotient rounds towards -infinity for positive b and towards
        // +infinity for negative b (see operator/ docs below).
        // Do it as 64-bit
        int64_t ia = a;
        int64_t ib = b;
        // All-ones when the corresponding value is negative, else zero.
        int64_t a_neg = ia >> 63;
        int64_t b_neg = ib >> 63;
        int64_t b_zero = (ib == 0) ? -1 : 0;
        // Make the divisor 1 when b == 0, avoiding a division by zero.
        ib -= b_zero;
        // Bias a negative numerator one step towards zero so C's
        // truncating / lands within one adjustment of the answer.
        ia -= a_neg;
        int64_t q = ia / ib;
        // When a was negative, adjust by -1 (b > 0) or +1 (b < 0):
        // (~b_neg - b_neg) is -1 or +1 respectively.
        q += a_neg & (~b_neg - b_neg);
        // Force the result to zero when b == 0.
        q &= ~b_zero;
        return (T)q;
    } else {
        // Unsigned integers: truncating division already matches.
        // (Float types never reach here; see the specializations below.)
        return a / b;
    }
}
7083// @}
7084
7085// Special cases for float, double.
7086template<>
7087inline float mod_imp<float>(float a, float b) {
7088 float f = a - b * (floorf(a / b));
7089 // The remainder has the same sign as b.
7090 return f;
7091}
7092template<>
7093inline double mod_imp<double>(double a, double b) {
7094 double f = a - b * (std::floor(a / b));
7095 return f;
7096}
7097
template<>
inline float div_imp<float>(float a, float b) {
    // Plain floating-point division; note that the b == 0 special case
    // in the primary template applies only to non-float types.
    return a / b;
}
template<>
inline double div_imp<double>(double a, double b) {
    return a / b;
}
7106
7107/** Return an Expr that is identical to the input Expr, but with
7108 * all calls to likely() and likely_if_innermost() removed. */
7109Expr remove_likelies(const Expr &e);
7110
7111/** Return a Stmt that is identical to the input Stmt, but with
7112 * all calls to likely() and likely_if_innermost() removed. */
7113Stmt remove_likelies(const Stmt &s);
7114
7115/** If the expression is a tag helper call, remove it and return
7116 * the tagged expression. If not, returns the expression. */
7117Expr unwrap_tags(const Expr &e);
7118
7119/** Expressions tagged with this intrinsic are suggestions that
7120 * vectorization of loops with guard ifs should be implemented with
7121 * non-faulting predicated loads and stores, instead of scalarizing
7122 * an if statement. */
7123Expr predicate(Expr e);
7124
7125// Secondary args to print can be Exprs or const char *
inline HALIDE_NO_USER_CODE_INLINE void collect_print_args(std::vector<Expr> &args) {
    // Base case: no more arguments to collect.
}

// Recursive case: a C string argument is appended as a string Expr.
template<typename... Args>
inline HALIDE_NO_USER_CODE_INLINE void collect_print_args(std::vector<Expr> &args, const char *arg, Args &&...more_args) {
    args.emplace_back(std::string(arg));
    collect_print_args(args, std::forward<Args>(more_args)...);
}

// Recursive case: an Expr argument is appended as-is.
template<typename... Args>
inline HALIDE_NO_USER_CODE_INLINE void collect_print_args(std::vector<Expr> &args, Expr arg, Args &&...more_args) {
    args.push_back(std::move(arg));
    collect_print_args(args, std::forward<Args>(more_args)...);
}
7140
7141Expr requirement_failed_error(Expr condition, const std::vector<Expr> &args);
7142
7143Expr memoize_tag_helper(Expr result, const std::vector<Expr> &cache_key_values);
7144
7145/** Compute widen(a) + widen(b). The result is always signed. */
7146Expr widening_add(Expr a, Expr b);
7147/** Compute widen(a) * widen(b). a and b may have different signedness. */
7148Expr widening_mul(Expr a, Expr b);
7149/** Compute widen(a) - widen(b). The result is always signed. */
7150Expr widening_sub(Expr a, Expr b);
7151/** Compute widen(a) << b. */
7152Expr widening_shift_left(Expr a, Expr b);
7153Expr widening_shift_left(Expr a, int b);
7154/** Compute widen(a) >> b. */
7155Expr widening_shift_right(Expr a, Expr b);
7156Expr widening_shift_right(Expr a, int b);
7157
7158/** Compute saturating_add(a, (1 >> min(b, 0)) / 2) << b. When b is positive
7159 * indicating a left shift, the rounding term is zero. */
7160Expr rounding_shift_left(Expr a, Expr b);
7161Expr rounding_shift_left(Expr a, int b);
7162/** Compute saturating_add(a, (1 << max(b, 0)) / 2) >> b. When b is negative
7163 * indicating a left shift, the rounding term is zero. */
7164Expr rounding_shift_right(Expr a, Expr b);
7165Expr rounding_shift_right(Expr a, int b);
7166
7167/** Compute saturating_narrow(widen(a) + widen(b)) */
7168Expr saturating_add(Expr a, Expr b);
7169/** Compute saturating_narrow(widen(a) - widen(b)) */
7170Expr saturating_sub(Expr a, Expr b);
7171
7172/** Compute narrow((widen(a) + widen(b)) / 2) */
7173Expr halving_add(Expr a, Expr b);
7174/** Compute narrow((widen(a) + widen(b) + 1) / 2) */
7175Expr rounding_halving_add(Expr a, Expr b);
7176/** Compute narrow((widen(a) - widen(b)) / 2) */
7177Expr halving_sub(Expr a, Expr b);
7178/** Compute narrow((widen(a) - widen(b) + 1) / 2) */
7179Expr rounding_halving_sub(Expr a, Expr b);
7180
7181/** Compute saturating_narrow(shift_right(widening_mul(a, b), q)) */
7182Expr mul_shift_right(Expr a, Expr b, Expr q);
7183Expr mul_shift_right(Expr a, Expr b, int q);
7184/** Compute saturating_narrow(rounding_shift_right(widening_mul(a, b), q)) */
7185Expr rounding_mul_shift_right(Expr a, Expr b, Expr q);
7186Expr rounding_mul_shift_right(Expr a, Expr b, int q);
7187
7188} // namespace Internal
7189
7190/** Cast an expression to the halide type corresponding to the C++ type T. */
template<typename T>
inline Expr cast(Expr a) {
    // Map the C++ type to a halide Type and dispatch to the
    // Type-based overload below.
    return cast(type_of<T>(), std::move(a));
}
7195
7196/** Cast an expression to a new type. */
7197Expr cast(Type t, Expr a);
7198
7199/** Return the sum of two expressions, doing any necessary type
7200 * coercion using \ref Internal::match_types */
7201Expr operator+(Expr a, Expr b);
7202
7203/** Add an expression and a constant integer. Coerces the type of the
7204 * integer to match the type of the expression. Errors if the integer
7205 * cannot be represented in the type of the expression. */
7206// @{
7207Expr operator+(Expr a, int b);
7208
7209/** Add a constant integer and an expression. Coerces the type of the
7210 * integer to match the type of the expression. Errors if the integer
7211 * cannot be represented in the type of the expression. */
7212Expr operator+(int a, Expr b);
7213
7214/** Modify the first expression to be the sum of two expressions,
7215 * without changing its type. This casts the second argument to match
7216 * the type of the first. */
7217Expr &operator+=(Expr &a, Expr b);
7218
7219/** Return the difference of two expressions, doing any necessary type
7220 * coercion using \ref Internal::match_types */
7221Expr operator-(Expr a, Expr b);
7222
7223/** Subtracts a constant integer from an expression. Coerces the type of the
7224 * integer to match the type of the expression. Errors if the integer
7225 * cannot be represented in the type of the expression. */
7226Expr operator-(Expr a, int b);
7227
7228/** Subtracts an expression from a constant integer. Coerces the type
7229 * of the integer to match the type of the expression. Errors if the
7230 * integer cannot be represented in the type of the expression. */
7231Expr operator-(int a, Expr b);
7232
7233/** Return the negative of the argument. Does no type casting, so more
7234 * formally: return that number which when added to the original,
7235 * yields zero of the same type. For unsigned integers the negative is
7236 * still an unsigned integer. E.g. in UInt(8), the negative of 56 is
7237 * 200, because 56 + 200 == 0 */
7238Expr operator-(Expr a);
7239
7240/** Modify the first expression to be the difference of two expressions,
7241 * without changing its type. This casts the second argument to match
7242 * the type of the first. */
7243Expr &operator-=(Expr &a, Expr b);
7244
7245/** Return the product of two expressions, doing any necessary type
7246 * coercion using \ref Internal::match_types */
7247Expr operator*(Expr a, Expr b);
7248
7249/** Multiply an expression and a constant integer. Coerces the type of the
7250 * integer to match the type of the expression. Errors if the integer
7251 * cannot be represented in the type of the expression. */
7252Expr operator*(Expr a, int b);
7253
7254/** Multiply a constant integer and an expression. Coerces the type of
7255 * the integer to match the type of the expression. Errors if the
7256 * integer cannot be represented in the type of the expression. */
7257Expr operator*(int a, Expr b);
7258
7259/** Modify the first expression to be the product of two expressions,
7260 * without changing its type. This casts the second argument to match
7261 * the type of the first. */
7262Expr &operator*=(Expr &a, Expr b);
7263
7264/** Return the ratio of two expressions, doing any necessary type
7265 * coercion using \ref Internal::match_types. Note that integer
7266 * division in Halide is not the same as integer division in C-like
7267 * languages in two ways.
7268 *
7269 * First, signed integer division in Halide rounds according to the
7270 * sign of the denominator. This means towards minus infinity for
7271 * positive denominators, and towards positive infinity for negative
7272 * denominators. This is unlike C, which rounds towards zero. This
7273 * decision ensures that upsampling expressions like f(x/2, y/2) don't
7274 * have funny discontinuities when x and y cross zero.
7275 *
7276 * Second, division by zero returns zero instead of faulting. For
7277 * types where overflow is defined behavior, division of the largest
 * negative signed integer by -1 returns the largest negative signed
7279 * integer for the type (i.e. it wraps). This ensures that a division
7280 * operation can never have a side-effect, which is helpful in Halide
7281 * because scheduling directives can expand the domain of computation
7282 * of a Func, potentially introducing new zero-division.
7283 */
7284Expr operator/(Expr a, Expr b);
7285
7286/** Modify the first expression to be the ratio of two expressions,
7287 * without changing its type. This casts the second argument to match
7288 * the type of the first. Note that signed integer division in Halide
7289 * rounds towards minus infinity, unlike C, which rounds towards
7290 * zero. */
7291Expr &operator/=(Expr &a, Expr b);
7292
7293/** Divides an expression by a constant integer. Coerces the type
7294 * of the integer to match the type of the expression. Errors if the
7295 * integer cannot be represented in the type of the expression. */
7296Expr operator/(Expr a, int b);
7297
7298/** Divides a constant integer by an expression. Coerces the type
7299 * of the integer to match the type of the expression. Errors if the
7300 * integer cannot be represented in the type of the expression. */
7301Expr operator/(int a, Expr b);
7302
7303/** Return the first argument reduced modulo the second, doing any
7304 * necessary type coercion using \ref Internal::match_types. There are
7305 * two key differences between C-like languages and Halide for the
7306 * modulo operation, which complement the way division works.
7307 *
7308 * First, the result is never negative, so x % 2 is always zero or
7309 * one, unlike in C-like languages. x % -2 is equivalent, and is also
7310 * always zero or one. Second, mod by zero evaluates to zero (unlike
7311 * in C, where it faults). This makes modulo, like division, a
7312 * side-effect-free operation. */
7313Expr operator%(Expr a, Expr b);
7314
7315/** Mods an expression by a constant integer. Coerces the type
7316 * of the integer to match the type of the expression. Errors if the
7317 * integer cannot be represented in the type of the expression. */
7318Expr operator%(Expr a, int b);
7319
7320/** Mods a constant integer by an expression. Coerces the type
7321 * of the integer to match the type of the expression. Errors if the
7322 * integer cannot be represented in the type of the expression. */
7323Expr operator%(int a, Expr b);
7324
7325/** Return a boolean expression that tests whether the first argument
7326 * is greater than the second, after doing any necessary type coercion
7327 * using \ref Internal::match_types */
7328Expr operator>(Expr a, Expr b);
7329
7330/** Return a boolean expression that tests whether an expression is
7331 * greater than a constant integer. Coerces the integer to the type of
7332 * the expression. Errors if the integer is not representable in that
7333 * type. */
7334Expr operator>(Expr a, int b);
7335
7336/** Return a boolean expression that tests whether a constant integer is
7337 * greater than an expression. Coerces the integer to the type of
7338 * the expression. Errors if the integer is not representable in that
7339 * type. */
7340Expr operator>(int a, Expr b);
7341
7342/** Return a boolean expression that tests whether the first argument
7343 * is less than the second, after doing any necessary type coercion
7344 * using \ref Internal::match_types */
7345Expr operator<(Expr a, Expr b);
7346
7347/** Return a boolean expression that tests whether an expression is
7348 * less than a constant integer. Coerces the integer to the type of
7349 * the expression. Errors if the integer is not representable in that
7350 * type. */
7351Expr operator<(Expr a, int b);
7352
7353/** Return a boolean expression that tests whether a constant integer is
7354 * less than an expression. Coerces the integer to the type of
7355 * the expression. Errors if the integer is not representable in that
7356 * type. */
7357Expr operator<(int a, Expr b);
7358
7359/** Return a boolean expression that tests whether the first argument
7360 * is less than or equal to the second, after doing any necessary type
7361 * coercion using \ref Internal::match_types */
7362Expr operator<=(Expr a, Expr b);
7363
7364/** Return a boolean expression that tests whether an expression is
7365 * less than or equal to a constant integer. Coerces the integer to
7366 * the type of the expression. Errors if the integer is not
7367 * representable in that type. */
7368Expr operator<=(Expr a, int b);
7369
7370/** Return a boolean expression that tests whether a constant integer
7371 * is less than or equal to an expression. Coerces the integer to the
7372 * type of the expression. Errors if the integer is not representable
7373 * in that type. */
7374Expr operator<=(int a, Expr b);
7375
7376/** Return a boolean expression that tests whether the first argument
7377 * is greater than or equal to the second, after doing any necessary
7378 * type coercion using \ref Internal::match_types */
7379Expr operator>=(Expr a, Expr b);
7380
7381/** Return a boolean expression that tests whether an expression is
7382 * greater than or equal to a constant integer. Coerces the integer to
7383 * the type of the expression. Errors if the integer is not
7384 * representable in that type. */
7385Expr operator>=(const Expr &a, int b);
7386
7387/** Return a boolean expression that tests whether a constant integer
7388 * is greater than or equal to an expression. Coerces the integer to the
7389 * type of the expression. Errors if the integer is not representable
7390 * in that type. */
7391Expr operator>=(int a, const Expr &b);
7392
7393/** Return a boolean expression that tests whether the first argument
7394 * is equal to the second, after doing any necessary type coercion
7395 * using \ref Internal::match_types */
7396Expr operator==(Expr a, Expr b);
7397
7398/** Return a boolean expression that tests whether an expression is
7399 * equal to a constant integer. Coerces the integer to the type of the
7400 * expression. Errors if the integer is not representable in that
7401 * type. */
7402Expr operator==(Expr a, int b);
7403
7404/** Return a boolean expression that tests whether a constant integer
7405 * is equal to an expression. Coerces the integer to the type of the
7406 * expression. Errors if the integer is not representable in that
7407 * type. */
7408Expr operator==(int a, Expr b);
7409
7410/** Return a boolean expression that tests whether the first argument
7411 * is not equal to the second, after doing any necessary type coercion
7412 * using \ref Internal::match_types */
7413Expr operator!=(Expr a, Expr b);
7414
7415/** Return a boolean expression that tests whether an expression is
7416 * not equal to a constant integer. Coerces the integer to the type of
7417 * the expression. Errors if the integer is not representable in that
7418 * type. */
7419Expr operator!=(Expr a, int b);
7420
7421/** Return a boolean expression that tests whether a constant integer
7422 * is not equal to an expression. Coerces the integer to the type of
7423 * the expression. Errors if the integer is not representable in that
7424 * type. */
7425Expr operator!=(int a, Expr b);
7426
7427/** Returns the logical and of the two arguments */
7428Expr operator&&(Expr a, Expr b);
7429
7430/** Logical and of an Expr and a bool. Either returns the Expr or an
7431 * Expr representing false, depending on the bool. */
7432// @{
7433Expr operator&&(Expr a, bool b);
7434Expr operator&&(bool a, Expr b);
7435// @}
7436
7437/** Returns the logical or of the two arguments */
7438Expr operator||(Expr a, Expr b);
7439
7440/** Logical or of an Expr and a bool. Either returns the Expr or an
7441 * Expr representing true, depending on the bool. */
7442// @{
7443Expr operator||(Expr a, bool b);
7444Expr operator||(bool a, Expr b);
7445// @}
7446
/** Returns the logical not of the argument */
7448Expr operator!(Expr a);
7449
7450/** Returns an expression representing the greater of the two
7451 * arguments, after doing any necessary type coercion using
7452 * \ref Internal::match_types. Vectorizes cleanly on most platforms
7453 * (with the exception of integer types on x86 without SSE4). */
7454Expr max(Expr a, Expr b);
7455
7456/** Returns an expression representing the greater of an expression
7457 * and a constant integer. The integer is coerced to the type of the
7458 * expression. Errors if the integer is not representable as that
7459 * type. Vectorizes cleanly on most platforms (with the exception of
7460 * integer types on x86 without SSE4). */
7461Expr max(Expr a, int b);
7462
7463/** Returns an expression representing the greater of a constant
7464 * integer and an expression. The integer is coerced to the type of
7465 * the expression. Errors if the integer is not representable as that
7466 * type. Vectorizes cleanly on most platforms (with the exception of
7467 * integer types on x86 without SSE4). */
7468Expr max(int a, Expr b);
7469
// max of a float and an Expr (in either order): the float is treated
// as an Expr, preventing an accidental implicit float->int conversion.
inline Expr max(float a, Expr b) {
    return max(Expr(a), std::move(b));
}
inline Expr max(Expr a, float b) {
    return max(std::move(a), Expr(b));
}
7476
/** Returns an expression representing the greatest of a list of
 * expressions, after doing any necessary type coercion using
7479 * \ref Internal::match_types. Vectorizes cleanly on most platforms
7480 * (with the exception of integer types on x86 without SSE4).
7481 * The expressions are folded from right ie. max(.., max(.., ..)).
7482 * The arguments can be any mix of types but must all be convertible to Expr. */
template<typename A, typename B, typename C, typename... Rest,
         typename std::enable_if<Halide::Internal::all_are_convertible<Expr, Rest...>::value>::type * = nullptr>
inline Expr max(A &&a, B &&b, C &&c, Rest &&...rest) {
    // Right fold: max(a, b, c, ...) == max(a, max(b, max(c, ...))).
    return max(std::forward<A>(a), max(std::forward<B>(b), std::forward<C>(c), std::forward<Rest>(rest)...));
}
7488
/** Returns an expression representing the lesser of the two
 * arguments, after doing any necessary type coercion using
 * \ref Internal::match_types. Vectorizes cleanly on most platforms
 * (with the exception of integer types on x86 without SSE4). */
Expr min(Expr a, Expr b);
7490
7491/** Returns an expression representing the lesser of an expression
7492 * and a constant integer. The integer is coerced to the type of the
7493 * expression. Errors if the integer is not representable as that
7494 * type. Vectorizes cleanly on most platforms (with the exception of
7495 * integer types on x86 without SSE4). */
7496Expr min(Expr a, int b);
7497
7498/** Returns an expression representing the lesser of a constant
7499 * integer and an expression. The integer is coerced to the type of
7500 * the expression. Errors if the integer is not representable as that
7501 * type. Vectorizes cleanly on most platforms (with the exception of
7502 * integer types on x86 without SSE4). */
7503Expr min(int a, Expr b);
7504
// min of a float and an Expr (in either order): the float is treated
// as an Expr, preventing an accidental implicit float->int conversion.
inline Expr min(float a, Expr b) {
    return min(Expr(a), std::move(b));
}
inline Expr min(Expr a, float b) {
    return min(std::move(a), Expr(b));
}
7511
7512/** Returns an expression representing the lesser of an expressions
 * vector, after doing any necessary type coercion using
7514 * \ref Internal::match_types. Vectorizes cleanly on most platforms
7515 * (with the exception of integer types on x86 without SSE4).
7516 * The expressions are folded from right ie. min(.., min(.., ..)).
7517 * The arguments can be any mix of types but must all be convertible to Expr. */
7518template<typename A, typename B, typename C, typename... Rest,
7519 typename std::enable_if<Halide::Internal::all_are_convertible<Expr, Rest...>::value>::type * = nullptr>
7520inline Expr min(A &&a, B &&b, C &&c, Rest &&...rest) {
7521 return min(std::forward<A>(a), min(std::forward<B>(b), std::forward<C>(c), std::forward<Rest>(rest)...));
7522}
7523
/** Operators on floats treat those floats as Exprs. Making these
 * overloads explicit prevents implicit float->int casts that might
 * otherwise occur. Each overload wraps the float operand in an Expr
 * and delegates to the corresponding Expr/Expr operator. */
// @{
// Arithmetic operators.
inline Expr operator+(Expr a, float b) {
    return std::move(a) + Expr(b);
}
inline Expr operator+(float a, Expr b) {
    return Expr(a) + std::move(b);
}
inline Expr operator-(Expr a, float b) {
    return std::move(a) - Expr(b);
}
inline Expr operator-(float a, Expr b) {
    return Expr(a) - std::move(b);
}
inline Expr operator*(Expr a, float b) {
    return std::move(a) * Expr(b);
}
inline Expr operator*(float a, Expr b) {
    return Expr(a) * std::move(b);
}
inline Expr operator/(Expr a, float b) {
    return std::move(a) / Expr(b);
}
inline Expr operator/(float a, Expr b) {
    return Expr(a) / std::move(b);
}
inline Expr operator%(Expr a, float b) {
    return std::move(a) % Expr(b);
}
inline Expr operator%(float a, Expr b) {
    return Expr(a) % std::move(b);
}
// Comparison operators.
inline Expr operator>(Expr a, float b) {
    return std::move(a) > Expr(b);
}
inline Expr operator>(float a, Expr b) {
    return Expr(a) > std::move(b);
}
inline Expr operator<(Expr a, float b) {
    return std::move(a) < Expr(b);
}
inline Expr operator<(float a, Expr b) {
    return Expr(a) < std::move(b);
}
inline Expr operator>=(Expr a, float b) {
    return std::move(a) >= Expr(b);
}
inline Expr operator>=(float a, Expr b) {
    return Expr(a) >= std::move(b);
}
inline Expr operator<=(Expr a, float b) {
    return std::move(a) <= Expr(b);
}
inline Expr operator<=(float a, Expr b) {
    return Expr(a) <= std::move(b);
}
// Equality operators.
inline Expr operator==(Expr a, float b) {
    return std::move(a) == Expr(b);
}
inline Expr operator==(float a, Expr b) {
    return Expr(a) == std::move(b);
}
inline Expr operator!=(Expr a, float b) {
    return std::move(a) != Expr(b);
}
inline Expr operator!=(float a, Expr b) {
    return Expr(a) != std::move(b);
}
// @}
7595
7596/** Clamps an expression to lie within the given bounds. The bounds
7597 * are type-cast to match the expression. Vectorizes as well as min/max. */
7598Expr clamp(Expr a, const Expr &min_val, const Expr &max_val);
7599
7600/** Returns the absolute value of a signed integer or floating-point
7601 * expression. Vectorizes cleanly. Unlike in C, abs of a signed
7602 * integer returns an unsigned integer of the same bit width. This
7603 * means that abs of the most negative integer doesn't overflow. */
7604Expr abs(Expr a);
7605
7606/** Return the absolute difference between two values. Vectorizes
7607 * cleanly. Returns an unsigned value of the same bit width. There are
7608 * various ways to write this yourself, but they contain numerous
7609 * gotchas and don't always compile to good code, so use this
7610 * instead. */
7611Expr absd(Expr a, Expr b);
7612
7613/** Returns an expression similar to the ternary operator in C, except
7614 * that it always evaluates all arguments. If the first argument is
7615 * true, then return the second, else return the third. Typically
7616 * vectorizes cleanly, but benefits from SSE41 or newer on x86. */
7617Expr select(Expr condition, Expr true_value, Expr false_value);
7618
7619/** A multi-way variant of select similar to a switch statement in C,
7620 * which can accept multiple conditions and values in pairs. Evaluates
7621 * to the first value for which the condition is true. Returns the
7622 * final value if all conditions are false. */
7623template<typename... Args,
7624 typename std::enable_if<Halide::Internal::all_are_convertible<Expr, Args...>::value>::type * = nullptr>
7625inline Expr select(Expr c0, Expr v0, Expr c1, Expr v1, Args &&...args) {
7626 return select(std::move(c0), std::move(v0), select(std::move(c1), std::move(v1), std::forward<Args>(args)...));
7627}
7628
7629/** Equivalent of ternary select(), but taking/returning tuples. If the condition is
7630 * a Tuple, it must match the size of the true and false Tuples. */
7631// @{
7632Tuple tuple_select(const Tuple &condition, const Tuple &true_value, const Tuple &false_value);
7633Tuple tuple_select(const Expr &condition, const Tuple &true_value, const Tuple &false_value);
7634// @}
7635
7636/** Equivalent of multiway select(), but taking/returning tuples. If the condition is
7637 * a Tuple, it must match the size of the true and false Tuples. */
7638// @{
7639template<typename... Args>
7640inline Tuple tuple_select(const Tuple &c0, const Tuple &v0, const Tuple &c1, const Tuple &v1, Args &&...args) {
7641 return tuple_select(c0, v0, tuple_select(c1, v1, std::forward<Args>(args)...));
7642}
7643
7644template<typename... Args>
7645inline Tuple tuple_select(const Expr &c0, const Tuple &v0, const Expr &c1, const Tuple &v1, Args &&...args) {
7646 return tuple_select(c0, v0, tuple_select(c1, v1, std::forward<Args>(args)...));
7647}
7648// @}
7649
7650/** Oftentimes we want to pack a list of expressions with the same type
7651 * into a channel dimension, e.g.,
7652 * img(x, y, c) = select(c == 0, 100, // Red
7653 * c == 1, 50, // Green
7654 * 25); // Blue
7655 * This is tedious when the list is long. The following function
 * provides convenient syntax that allows one to write:
7657 * img(x, y, c) = mux(c, {100, 50, 25});
7658 *
7659 * As with the select equivalent, if the first argument (the index) is
7660 * out of range, the expression evaluates to the last value.
7661 */
7662// @{
7663Expr mux(const Expr &id, const std::initializer_list<Expr> &values);
7664Expr mux(const Expr &id, const std::vector<Expr> &values);
7665Expr mux(const Expr &id, const Tuple &values);
7666// @}
7667
7668/** Return the sine of a floating-point expression. If the argument is
7669 * not floating-point, it is cast to Float(32). Does not vectorize
7670 * well. */
7671Expr sin(Expr x);
7672
7673/** Return the arcsine of a floating-point expression. If the argument
7674 * is not floating-point, it is cast to Float(32). Does not vectorize
7675 * well. */
7676Expr asin(Expr x);
7677
7678/** Return the cosine of a floating-point expression. If the argument
7679 * is not floating-point, it is cast to Float(32). Does not vectorize
7680 * well. */
7681Expr cos(Expr x);
7682
7683/** Return the arccosine of a floating-point expression. If the
7684 * argument is not floating-point, it is cast to Float(32). Does not
7685 * vectorize well. */
7686Expr acos(Expr x);
7687
7688/** Return the tangent of a floating-point expression. If the argument
7689 * is not floating-point, it is cast to Float(32). Does not vectorize
7690 * well. */
7691Expr tan(Expr x);
7692
7693/** Return the arctangent of a floating-point expression. If the
7694 * argument is not floating-point, it is cast to Float(32). Does not
7695 * vectorize well. */
7696Expr atan(Expr x);
7697
7698/** Return the angle of a floating-point gradient. If the argument is
7699 * not floating-point, it is cast to Float(32). Does not vectorize
7700 * well. */
7701Expr atan2(Expr y, Expr x);
7702
7703/** Return the hyperbolic sine of a floating-point expression. If the
7704 * argument is not floating-point, it is cast to Float(32). Does not
7705 * vectorize well. */
7706Expr sinh(Expr x);
7707
/** Return the hyperbolic arcsine of a floating-point expression. If
7709 * the argument is not floating-point, it is cast to Float(32). Does
7710 * not vectorize well. */
7711Expr asinh(Expr x);
7712
7713/** Return the hyperbolic cosine of a floating-point expression. If
7714 * the argument is not floating-point, it is cast to Float(32). Does
7715 * not vectorize well. */
7716Expr cosh(Expr x);
7717
7718/** Return the hyperbolic arccosine of a floating-point expression.
7719 * If the argument is not floating-point, it is cast to
7720 * Float(32). Does not vectorize well. */
7721Expr acosh(Expr x);
7722
7723/** Return the hyperbolic tangent of a floating-point expression. If
7724 * the argument is not floating-point, it is cast to Float(32). Does
7725 * not vectorize well. */
7726Expr tanh(Expr x);
7727
7728/** Return the hyperbolic arctangent of a floating-point expression.
7729 * If the argument is not floating-point, it is cast to
7730 * Float(32). Does not vectorize well. */
7731Expr atanh(Expr x);
7732
7733/** Return the square root of a floating-point expression. If the
7734 * argument is not floating-point, it is cast to Float(32). Typically
7735 * vectorizes cleanly. */
7736Expr sqrt(Expr x);
7737
7738/** Return the square root of the sum of the squares of two
7739 * floating-point expressions. If the argument is not floating-point,
7740 * it is cast to Float(32). Vectorizes cleanly. */
7741Expr hypot(const Expr &x, const Expr &y);
7742
7743/** Return the exponential of a floating-point expression. If the
7744 * argument is not floating-point, it is cast to Float(32). For
7745 * Float(64) arguments, this calls the system exp function, and does
7746 * not vectorize well. For Float(32) arguments, this function is
7747 * vectorizable, does the right thing for extremely small or extremely
7748 * large inputs, and is accurate up to the last bit of the
7749 * mantissa. Vectorizes cleanly. */
7750Expr exp(Expr x);
7751
7752/** Return the logarithm of a floating-point expression. If the
7753 * argument is not floating-point, it is cast to Float(32). For
7754 * Float(64) arguments, this calls the system log function, and does
7755 * not vectorize well. For Float(32) arguments, this function is
7756 * vectorizable, does the right thing for inputs <= 0 (returns -inf or
7757 * nan), and is accurate up to the last bit of the
7758 * mantissa. Vectorizes cleanly. */
7759Expr log(Expr x);
7760
7761/** Return one floating point expression raised to the power of
7762 * another. The type of the result is given by the type of the first
7763 * argument. If the first argument is not a floating-point type, it is
7764 * cast to Float(32). For Float(32), cleanly vectorizable, and
7765 * accurate up to the last few bits of the mantissa. Gets worse when
7766 * approaching overflow. Vectorizes cleanly. */
7767Expr pow(Expr x, Expr y);
7768
7769/** Evaluate the error function erf. Only available for
7770 * Float(32). Accurate up to the last three bits of the
7771 * mantissa. Vectorizes cleanly. */
7772Expr erf(const Expr &x);
7773
7774/** Fast vectorizable approximation to some trigonometric functions for Float(32).
7775 * Absolute approximation error is less than 1e-5. */
7776// @{
7777Expr fast_sin(const Expr &x);
7778Expr fast_cos(const Expr &x);
7779// @}
7780
7781/** Fast approximate cleanly vectorizable log for Float(32). Returns
7782 * nonsense for x <= 0.0f. Accurate up to the last 5 bits of the
7783 * mantissa. Vectorizes cleanly. */
7784Expr fast_log(const Expr &x);
7785
7786/** Fast approximate cleanly vectorizable exp for Float(32). Returns
7787 * nonsense for inputs that would overflow or underflow. Typically
7788 * accurate up to the last 5 bits of the mantissa. Gets worse when
7789 * approaching overflow. Vectorizes cleanly. */
7790Expr fast_exp(const Expr &x);
7791
7792/** Fast approximate cleanly vectorizable pow for Float(32). Returns
7793 * nonsense for x < 0.0f. Accurate up to the last 5 bits of the
7794 * mantissa for typical exponents. Gets worse when approaching
7795 * overflow. Vectorizes cleanly. */
7796Expr fast_pow(Expr x, Expr y);
7797
7798/** Fast approximate inverse for Float(32). Corresponds to the rcpps
7799 * instruction on x86, and the vrecpe instruction on ARM. Vectorizes
7800 * cleanly. Note that this can produce slightly different results
7801 * across different implementations of the same architecture (e.g. AMD vs Intel),
7802 * even when strict_float is enabled. */
7803Expr fast_inverse(Expr x);
7804
7805/** Fast approximate inverse square root for Float(32). Corresponds to
7806 * the rsqrtps instruction on x86, and the vrsqrte instruction on
7807 * ARM. Vectorizes cleanly. Note that this can produce slightly different results
7808 * across different implementations of the same architecture (e.g. AMD vs Intel),
7809 * even when strict_float is enabled. */
7810Expr fast_inverse_sqrt(Expr x);
7811
7812/** Return the greatest whole number less than or equal to a
7813 * floating-point expression. If the argument is not floating-point,
7814 * it is cast to Float(32). The return value is still in floating
7815 * point, despite being a whole number. Vectorizes cleanly. */
7816Expr floor(Expr x);
7817
7818/** Return the least whole number greater than or equal to a
7819 * floating-point expression. If the argument is not floating-point,
7820 * it is cast to Float(32). The return value is still in floating
7821 * point, despite being a whole number. Vectorizes cleanly. */
7822Expr ceil(Expr x);
7823
7824/** Return the whole number closest to a floating-point expression. If the
7825 * argument is not floating-point, it is cast to Float(32). The return value
7826 * is still in floating point, despite being a whole number. On ties, we
7827 * follow IEEE754 conventions and round to the nearest even number. Vectorizes
7828 * cleanly. */
7829Expr round(Expr x);
7830
7831/** Return the integer part of a floating-point expression. If the argument is
7832 * not floating-point, it is cast to Float(32). The return value is still in
7833 * floating point, despite being a whole number. Vectorizes cleanly. */
7834Expr trunc(Expr x);
7835
7836/** Returns true if the argument is a Not a Number (NaN). Requires a
7837 * floating point argument. Vectorizes cleanly.
7838 * Note that the Expr passed in will be evaluated in strict_float mode,
7839 * regardless of whether strict_float mode is enabled in the current Target. */
7840Expr is_nan(Expr x);
7841
7842/** Returns true if the argument is Inf or -Inf. Requires a
7843 * floating point argument. Vectorizes cleanly.
7844 * Note that the Expr passed in will be evaluated in strict_float mode,
7845 * regardless of whether strict_float mode is enabled in the current Target. */
7846Expr is_inf(Expr x);
7847
7848/** Returns true if the argument is a finite value (ie, neither NaN nor Inf).
7849 * Requires a floating point argument. Vectorizes cleanly.
7850 * Note that the Expr passed in will be evaluated in strict_float mode,
7851 * regardless of whether strict_float mode is enabled in the current Target. */
7852Expr is_finite(Expr x);
7853
7854/** Return the fractional part of a floating-point expression. If the argument
7855 * is not floating-point, it is cast to Float(32). The return value has the
7856 * same sign as the original expression. Vectorizes cleanly. */
7857Expr fract(const Expr &x);
7858
7859/** Reinterpret the bits of one value as another type. */
7860Expr reinterpret(Type t, Expr e);
7861
7862template<typename T>
7863Expr reinterpret(Expr e) {
7864 return reinterpret(type_of<T>(), e);
7865}
7866
7867/** Return the bitwise and of two expressions (which need not have the
7868 * same type). The result type is the wider of the two expressions.
7869 * Only integral types are allowed and both expressions must be signed
7870 * or both must be unsigned. */
7871Expr operator&(Expr x, Expr y);
7872
7873/** Return the bitwise and of an expression and an integer. The type
7874 * of the result is the type of the expression argument. */
7875// @{
7876Expr operator&(Expr x, int y);
7877Expr operator&(int x, Expr y);
7878// @}
7879
7880/** Return the bitwise or of two expressions (which need not have the
7881 * same type). The result type is the wider of the two expressions.
7882 * Only integral types are allowed and both expressions must be signed
7883 * or both must be unsigned. */
7884Expr operator|(Expr x, Expr y);
7885
7886/** Return the bitwise or of an expression and an integer. The type of
7887 * the result is the type of the expression argument. */
7888// @{
7889Expr operator|(Expr x, int y);
7890Expr operator|(int x, Expr y);
7891// @}
7892
7893/** Return the bitwise xor of two expressions (which need not have the
7894 * same type). The result type is the wider of the two expressions.
7895 * Only integral types are allowed and both expressions must be signed
7896 * or both must be unsigned. */
7897Expr operator^(Expr x, Expr y);
7898
7899/** Return the bitwise xor of an expression and an integer. The type
7900 * of the result is the type of the expression argument. */
7901// @{
7902Expr operator^(Expr x, int y);
7903Expr operator^(int x, Expr y);
7904// @}
7905
7906/** Return the bitwise not of an expression. */
7907Expr operator~(Expr x);
7908
7909/** Shift the bits of an integer value left. This is actually less
7910 * efficient than multiplying by 2^n, because Halide's optimization
7911 * passes understand multiplication, and will compile it to
7912 * shifting. This operator is only for if you really really need bit
7913 * shifting (e.g. because the exponent is a run-time parameter). The
7914 * type of the result is equal to the type of the first argument. Both
7915 * arguments must have integer type. */
7916// @{
7917Expr operator<<(Expr x, Expr y);
7918Expr operator<<(Expr x, int y);
7919// @}
7920
7921/** Shift the bits of an integer value right. Does sign extension for
7922 * signed integers. This is less efficient than dividing by a power of
7923 * two. Halide's definition of division (always round to negative
7924 * infinity) means that all divisions by powers of two get compiled to
7925 * bit-shifting, and Halide's optimization routines understand
7926 * division and can work with it. The type of the result is equal to
7927 * the type of the first argument. Both arguments must have integer
7928 * type. */
7929// @{
7930Expr operator>>(Expr x, Expr y);
7931Expr operator>>(Expr x, int y);
7932// @}
7933
7934/** Linear interpolate between the two values according to a weight.
7935 * \param zero_val The result when weight is 0
7936 * \param one_val The result when weight is 1
7937 * \param weight The interpolation amount
7938 *
7939 * Both zero_val and one_val must have the same type. All types are
7940 * supported, including bool.
7941 *
7942 * The weight is treated as its own type and must be float or an
7943 * unsigned integer type. It is scaled to the bit-size of the type of
7944 * x and y if they are integer, or converted to float if they are
7945 * float. Integer weights are converted to float via division by the
7946 * full-range value of the weight's type. Floating-point weights used
7947 * to interpolate between integer values must be between 0.0f and
7948 * 1.0f, and an error may be signaled if it is not provably so. (clamp
7949 * operators can be added to provide proof. Currently an error is only
7950 * signalled for constant weights.)
7951 *
7952 * For integer linear interpolation, out of range values cannot be
7953 * represented. In particular, weights that are conceptually less than
7954 * 0 or greater than 1.0 are not representable. As such the result is
7955 * always between x and y (inclusive of course). For lerp with
7956 * floating-point values and floating-point weight, the full range of
7957 * a float is valid, however underflow and overflow can still occur.
7958 *
7959 * Ordering is not required between zero_val and one_val:
7960 * lerp(42, 69, .5f) == lerp(69, 42, .5f) == 56
7961 *
7962 * Results for integer types are for exactly rounded arithmetic. As
7963 * such, there are cases where 16-bit and float differ because 32-bit
7964 * floating-point (float) does not have enough precision to produce
7965 * the exact result. (Likely true for 32-bit integer
7966 * vs. double-precision floating-point as well.)
7967 *
7968 * At present, double precision and 64-bit integers are not supported.
7969 *
7970 * Generally, lerp will vectorize as if it were an operation on a type
7971 * twice the bit size of the inferred type for x and y.
7972 *
7973 * Some examples:
7974 * \code
7975 *
 * // Since Halide does not have direct type declarations, casts
7977 * // below are used to indicate the types of the parameters.
7978 * // Such casts not required or expected in actual code where types
7979 * // are inferred.
7980 *
7981 * lerp(cast<float>(x), cast<float>(y), cast<float>(w)) ->
7982 * x * (1.0f - w) + y * w
7983 *
7984 * lerp(cast<uint8_t>(x), cast<uint8_t>(y), cast<uint8_t>(w)) ->
7985 * cast<uint8_t>(cast<uint8_t>(x) * (1.0f - cast<uint8_t>(w) / 255.0f) +
7986 * cast<uint8_t>(y) * cast<uint8_t>(w) / 255.0f + .5f)
7987 *
7988 * // Note addition in Halide promoted uint8_t + int8_t to int16_t already,
7989 * // the outer cast is added for clarity.
7990 * lerp(cast<uint8_t>(x), cast<int8_t>(y), cast<uint8_t>(w)) ->
7991 * cast<int16_t>(cast<uint8_t>(x) * (1.0f - cast<uint8_t>(w) / 255.0f) +
7992 * cast<int8_t>(y) * cast<uint8_t>(w) / 255.0f + .5f)
7993 *
7994 * lerp(cast<int8_t>(x), cast<int8_t>(y), cast<float>(w)) ->
7995 * cast<int8_t>(cast<int8_t>(x) * (1.0f - cast<float>(w)) +
7996 * cast<int8_t>(y) * cast<uint8_t>(w))
7997 *
7998 * \endcode
7999 * */
8000Expr lerp(Expr zero_val, Expr one_val, Expr weight);
8001
8002/** Count the number of set bits in an expression. */
8003Expr popcount(Expr x);
8004
8005/** Count the number of leading zero bits in an expression. If the expression is
8006 * zero, the result is the number of bits in the type. */
8007Expr count_leading_zeros(Expr x);
8008
8009/** Count the number of trailing zero bits in an expression. If the expression is
8010 * zero, the result is the number of bits in the type. */
8011Expr count_trailing_zeros(Expr x);
8012
8013/** Divide two integers, rounding towards zero. This is the typical
8014 * behavior of most hardware architectures, which differs from
8015 * Halide's division operator, which is Euclidean (rounds towards
8016 * -infinity). Will throw a runtime error if y is zero, or if y is -1
8017 * and x is the minimum signed integer. */
8018Expr div_round_to_zero(Expr x, Expr y);
8019
8020/** Compute the remainder of dividing two integers, when division is
8021 * rounding toward zero. This is the typical behavior of most hardware
8022 * architectures, which differs from Halide's mod operator, which is
8023 * Euclidean (produces the remainder when division rounds towards
8024 * -infinity). Will throw a runtime error if y is zero. */
8025Expr mod_round_to_zero(Expr x, Expr y);
8026
8027/** Return a random variable representing a uniformly distributed
8028 * float in the half-open interval [0.0f, 1.0f). For random numbers of
8029 * other types, use lerp with a random float as the last parameter.
8030 *
8031 * Optionally takes a seed.
8032 *
8033 * Note that:
8034 \code
8035 Expr x = random_float();
8036 Expr y = x + x;
8037 \endcode
8038 *
8039 * is very different to
8040 *
8041 \code
8042 Expr y = random_float() + random_float();
8043 \endcode
8044 *
8045 * The first doubles a random variable, and the second adds two
8046 * independent random variables.
8047 *
8048 * A given random variable takes on a unique value that depends
8049 * deterministically on the pure variables of the function they belong
8050 * to, the identity of the function itself, and which definition of
8051 * the function it is used in. They are, however, shared across tuple
8052 * elements.
8053 *
8054 * This function vectorizes cleanly.
8055 */
8056Expr random_float(Expr seed = Expr());
8057
8058/** Return a random variable representing a uniformly distributed
8059 * unsigned 32-bit integer. See \ref random_float. Vectorizes cleanly. */
8060Expr random_uint(Expr seed = Expr());
8061
8062/** Return a random variable representing a uniformly distributed
8063 * 32-bit integer. See \ref random_float. Vectorizes cleanly. */
8064Expr random_int(Expr seed = Expr());
8065
8066/** Create an Expr that prints out its value whenever it is
8067 * evaluated. It also prints out everything else in the arguments
8068 * list, separated by spaces. This can include string literals. */
8069//@{
8070Expr print(const std::vector<Expr> &values);
8071
8072template<typename... Args>
8073inline HALIDE_NO_USER_CODE_INLINE Expr print(Expr a, Args &&...args) {
8074 std::vector<Expr> collected_args = {std::move(a)};
8075 Internal::collect_print_args(collected_args, std::forward<Args>(args)...);
8076 return print(collected_args);
8077}
8078//@}
8079
8080/** Create an Expr that prints whenever it is evaluated, provided that
8081 * the condition is true. */
8082// @{
8083Expr print_when(Expr condition, const std::vector<Expr> &values);
8084
8085template<typename... Args>
8086inline HALIDE_NO_USER_CODE_INLINE Expr print_when(Expr condition, Expr a, Args &&...args) {
8087 std::vector<Expr> collected_args = {std::move(a)};
8088 Internal::collect_print_args(collected_args, std::forward<Args>(args)...);
8089 return print_when(std::move(condition), collected_args);
8090}
8091
8092// @}
8093
/** Create an Expr that guarantees a precondition.
8095 * If 'condition' is true, the return value is equal to the first Expr.
8096 * If 'condition' is false, halide_error() is called, and the return value
8097 * is arbitrary. Any additional arguments after the first Expr are stringified
8098 * and passed as a user-facing message to halide_error(), similar to print().
8099 *
8100 * Note that this essentially *always* inserts a runtime check into the
8101 * generated code (except when the condition can be proven at compile time);
8102 * as such, it should be avoided inside inner loops, except for debugging
8103 * or testing purposes. Note also that it does not vectorize cleanly (vector
8104 * values will be scalarized for the check).
8105 *
8106 * However, using this to make assertions about (say) input values
8107 * can be useful, both in terms of correctness and (potentially) in terms
8108 * of code generation, e.g.
8109 \code
8110 Param<int> p;
8111 Expr y = require(p > 0, p);
8112 \endcode
8113 * will allow the optimizer to assume positive, nonzero values for y.
8114 */
8115// @{
8116Expr require(Expr condition, const std::vector<Expr> &values);
8117
8118template<typename... Args>
8119inline HALIDE_NO_USER_CODE_INLINE Expr require(Expr condition, Expr value, Args &&...args) {
8120 std::vector<Expr> collected_args = {std::move(value)};
8121 Internal::collect_print_args(collected_args, std::forward<Args>(args)...);
8122 return require(std::move(condition), collected_args);
8123}
8124// @}
8125
8126/** Return an undef value of the given type. Halide skips stores that
8127 * depend on undef values, so you can use this to mean "do not modify
8128 * this memory location". This is an escape hatch that can be used for
8129 * several things:
8130 *
8131 * You can define a reduction with no pure step, by setting the pure
8132 * step to undef. Do this only if you're confident that the update
8133 * steps are sufficient to correctly fill in the domain.
8134 *
8135 * For a tuple-valued reduction, you can write an update step that
8136 * only updates some tuple elements.
8137 *
8138 * You can define single-stage pipeline that only has update steps,
8139 * and depends on the values already in the output buffer.
8140 *
8141 * Use this feature with great caution, as you can use it to load from
8142 * uninitialized memory.
8143 */
8144Expr undef(Type t);
8145
8146template<typename T>
8147inline Expr undef() {
8148 return undef(type_of<T>());
8149}
8150
8151namespace Internal {
8152
8153/** Return an expression that should never be evaluated. Expressions
 * that depend on unreachable values are also unreachable, and
8155 * statements that execute unreachable expressions are also considered
8156 * unreachable. */
8157Expr unreachable(Type t = Int(32));
8158
8159template<typename T>
8160inline Expr unreachable() {
8161 return unreachable(type_of<T>());
8162}
8163
8164} // namespace Internal
8165
8166/** Control the values used in the memoization cache key for memoize.
8167 * Normally parameters and other external dependencies are
8168 * automatically inferred and added to the cache key. The memoize_tag
8169 * operator allows computing one expression and using either the
8170 * computed value, or one or more other expressions in the cache key
8171 * instead of the parameter dependencies of the computation. The
8172 * single argument version is completely safe in that the cache key
 * will use the actual computed value -- it is difficult or impossible
8174 * to produce erroneous caching this way. The more-than-one argument
8175 * version allows generating cache keys that do not uniquely identify
8176 * the computation and thus can result in caching errors.
8177 *
8178 * A potential use for the single argument version is to handle a
8179 * floating-point parameter that is quantized to a small
 * integer. Multiple values of the float will produce the same integer
8181 * and moving the caching to using the integer for the key is more
8182 * efficient.
8183 *
8184 * The main use for the more-than-one argument version is to provide
8185 * cache key information for Handles and ImageParams, which otherwise
8186 * are not allowed inside compute_cached operations. E.g. when passing
8187 * a group of parameters to an external array function via a Handle,
8188 * memoize_tag can be used to isolate the actual values used by that
8189 * computation. If an ImageParam is a constant image with a persistent
8190 * digest, memoize_tag can be used to key computations using that image
8191 * on the digest. */
8192// @{
8193template<typename... Args>
8194inline HALIDE_NO_USER_CODE_INLINE Expr memoize_tag(Expr result, Args &&...args) {
8195 std::vector<Expr> collected_args{std::forward<Args>(args)...};
8196 return Internal::memoize_tag_helper(std::move(result), collected_args);
8197}
8198// @}
8199
8200/** Expressions tagged with this intrinsic are considered to be part
8201 * of the steady state of some loop with a nasty beginning and end
8202 * (e.g. a boundary condition). When Halide encounters likely
8203 * intrinsics, it splits the containing loop body into three, and
8204 * tries to simplify down all conditions that lead to the likely. For
8205 * example, given the expression: select(x < 1, bar, x > 10, bar,
8206 * likely(foo)), Halide will split the loop over x into portions where
8207 * x < 1, 1 <= x <= 10, and x > 10.
8208 *
8209 * You're unlikely to want to call this directly. You probably want to
8210 * use the boundary condition helpers in the BoundaryConditions
8211 * namespace instead.
8212 */
8213Expr likely(Expr e);
8214
8215/** Equivalent to likely, but only triggers a loop partitioning if
8216 * found in an innermost loop. */
8217Expr likely_if_innermost(Expr e);
8218
8219/** Cast an expression to the halide type corresponding to the C++
8220 * type T. As part of the cast, clamp to the minimum and maximum
8221 * values of the result type. */
8222template<typename T>
8223Expr saturating_cast(Expr e) {
8224 return saturating_cast(type_of<T>(), std::move(e));
8225}
8226
8227/** Cast an expression to a new type, clamping to the minimum and
8228 * maximum values of the result type. */
8229Expr saturating_cast(Type t, Expr e);
8230
8231/** Makes a best effort attempt to preserve IEEE floating-point
8232 * semantics in evaluating an expression. May not be implemented for
8233 * all backends. (E.g. it is difficult to do this for C++ code
8234 * generation as it depends on the compiler flags used to compile the
 * generated code.) */
8236Expr strict_float(Expr e);
8237
8238/** Create an Expr that that promises another Expr is clamped but do
8239 * not generate code to check the assertion or modify the value. No
8240 * attempt is made to prove the bound at compile time. (If it is
8241 * proved false as a result of something else, an error might be
8242 * generated, but it is also possible the compiler will crash.) The
8243 * promised bound is used in bounds inference so it will allow
8244 * satisfying bounds checks as well as possibly aiding optimization.
8245 *
8246 * unsafe_promise_clamped returns its first argument, the Expr 'value'
8247 *
8248 * This is a very easy way to make Halide generate erroneous code if
 * the promised bound is not kept. Use sparingly when there is no
8250 * other way to convey the information to the compiler and it is
8251 * required for a valuable optimization.
8252 *
8253 * Unsafe promises can be checked by turning on
8254 * Target::CheckUnsafePromises. This is intended for debugging only.
8255 */
8256Expr unsafe_promise_clamped(const Expr &value, const Expr &min, const Expr &max);
8257
8258namespace Internal {
8259/**
8260 * FOR INTERNAL USE ONLY.
8261 *
8262 * An entirely unchecked version of unsafe_promise_clamped, used
8263 * inside the compiler as an annotation of the known bounds of an Expr
8264 * when it has proved something is bounded and wants to record that
8265 * fact for later passes (notably bounds inference) to exploit. This
8266 * gets introduced by GuardWithIf tail strategies, because the bounds
8267 * machinery has a hard time exploiting if statement conditions.
8268 *
8269 * Unlike unsafe_promise_clamped, this expression is
8270 * context-dependent, because 'value' might be statically bounded at
8271 * some point in the IR (e.g. due to a containing if statement), but
8272 * not elsewhere.
8273 **/
8274Expr promise_clamped(const Expr &value, const Expr &min, const Expr &max);
8275} // namespace Internal
8276
8277/** Scatter and gather are used for update definition which must store
8278 * multiple values to distinct locations at the same time. The
8279 * multiple expressions on the right-hand-side are bundled together
 * into a "gather", which must match a "scatter" with the same number
 * of arguments on the left-hand-side. For example, to store the
8282 * values 1 and 2 to the locations (x, y, 3) and (x, y, 4),
8283 * respectively:
8284 *
8285\code
8286f(x, y, scatter(3, 4)) = gather(1, 2);
8287\endcode
8288 *
8289 * The result of gather or scatter can be treated as an
8290 * expression. Any containing operations on it can be assumed to
8291 * distribute over the elements. If two gather expressions are
8292 * combined with an arithmetic operator (e.g. added), they combine
8293 * element-wise. The following example stores the values 2 * x, 2 * y,
8294 * and 2 * c to the locations (x + 1, y, c), (x, y + 3, c), and (x, y,
8295 * c + 2) respectively:
8296 *
8297\code
8298f(x + scatter(1, 0, 0), y + scatter(0, 3, 0), c + scatter(0, 0, 2)) = 2 * gather(x, y, c);
8299\endcode
8300*
8301* Repeated values in the scatter cause multiple stores to the same
8302* location. The stores happen in order from left to right, so the
8303* rightmost value wins. The following code is equivalent to f(x) = 5
8304*
8305\code
8306f(scatter(x, x)) = gather(3, 5);
8307\endcode
8308*
8309* Gathers are most useful for algorithms which require in-place
8310* swapping or permutation of multiple elements, or other kinds of
8311* in-place mutations that require loading multiple inputs, doing some
8312* operations to them jointly, then storing them again. The following
8313* update definition swaps the values of f at locations 3 and 5 if an
8314* input parameter p is true:
8315*
8316\code
8317f(scatter(3, 5)) = f(select(p, gather(5, 3), gather(3, 5)));
8318\endcode
8319*
8320* For more examples of the use of scatter and gather, see
8321* test/correctness/multiple_scatter.cpp
8322*
8323* It is not currently possible to use scatter and gather to write an
8324* update definition in which the *number* of values loaded or stored
* varies, as the size of the scatter/gather packet must be fixed at
* compile-time. A workaround is to make the unwanted extra operations
8327* a redundant copy of the last operation, which will be
8328* dead-code-eliminated by the compiler. For example, the following
8329* update definition swaps the values at locations 3 and 5 when the
8330* parameter p is true, and rotates the values at locations 1, 2, and 3
8331* when it is false. The load from 3 and store to 5 will be redundantly
8332* repeated:
8333*
8334\code
8335f(select(p, scatter(3, 5, 5), scatter(1, 2, 3))) = f(select(p, gather(5, 3, 3), gather(2, 3, 1)));
8336\endcode
8337*
* Note that in the p == true case, we redundantly load from 3 and write
8339* to 5 twice.
8340*/
8341//@{
8342Expr scatter(const std::vector<Expr> &args);
8343Expr gather(const std::vector<Expr> &args);
8344
8345template<typename... Args>
8346Expr scatter(const Expr &e, Args &&...args) {
8347 return scatter({e, std::forward<Args>(args)...});
8348}
8349
8350template<typename... Args>
8351Expr gather(const Expr &e, Args &&...args) {
8352 return gather({e, std::forward<Args>(args)...});
8353}
8354// @}
8355
8356} // namespace Halide
8357
8358#endif
8359
8360#include <utility>
8361#include <vector>
8362
8363namespace Halide {
8364namespace Internal {
8365
8366/**
8367 * Represent an associative op with its identity. The op may be multi-dimensional,
8368 * e.g. complex multiplication. 'is_commutative' is set to true if the op is also
8369 * commutative in addition to being associative.
8370 *
8371 * For example, complex multiplication is represented as:
8372 \code
8373 AssociativePattern pattern(
8374 {x0 * y0 - x1 * y1, x1 * y0 + x0 * y1},
8375 {one, zero},
8376 true
8377 );
8378 \endcode
8379 */
8380struct AssociativePattern {
8381 /** Contain the binary operators for each dimension of the associative op. */
8382 std::vector<Expr> ops;
8383 /** Contain the identities for each dimension of the associative op. */
8384 std::vector<Expr> identities;
8385 /** Indicate if the associative op is also commutative. */
8386 bool is_commutative = false;
8387
8388 AssociativePattern() = default;
8389 AssociativePattern(size_t size)
8390 : ops(size), identities(size) {
8391 }
8392 AssociativePattern(const std::vector<Expr> &ops, const std::vector<Expr> &ids, bool is_commutative)
8393 : ops(ops), identities(ids), is_commutative(is_commutative) {
8394 }
8395 AssociativePattern(Expr op, Expr id, bool is_commutative)
8396 : ops({std::move(op)}), identities({std::move(id)}), is_commutative(is_commutative) {
8397 }
8398
8399 bool operator==(const AssociativePattern &other) const {
8400 if ((is_commutative != other.is_commutative) || (ops.size() != other.ops.size())) {
8401 return false;
8402 }
8403 for (size_t i = 0; i < size(); ++i) {
8404 if (!equal(ops[i], other.ops[i]) || !equal(identities[i], other.identities[i])) {
8405 return false;
8406 }
8407 }
8408 return true;
8409 }
8410 bool operator!=(const AssociativePattern &other) const {
8411 return !(*this == other);
8412 }
8413 size_t size() const {
8414 return ops.size();
8415 }
8416 bool commutative() const {
8417 return is_commutative;
8418 }
8419};
8420
8421const std::vector<AssociativePattern> &get_ops_table(const std::vector<Expr> &exprs);
8422
8423} // namespace Internal
8424} // namespace Halide
8425
8426#endif
8427#ifndef HALIDE_ASSOCIATIVITY_H
8428#define HALIDE_ASSOCIATIVITY_H
8429
8430/** \file
8431 *
8432 * Methods for extracting an associative operator from a Func's update definition
8433 * if there is any and computing the identity of the associative operator.
8434 */
8435#include <string>
8436#include <vector>
8437
8438
8439namespace Halide {
8440namespace Internal {
8441
8442/**
8443 * Represent the equivalent associative op of an update definition.
8444 * For example, the following associative Expr, min(f(x), g(r.x) + 2),
8445 * where f(x) is the self-recurrence term, is represented as:
8446 \code
8447 AssociativeOp assoc(
8448 AssociativePattern(min(x, y), +inf, true),
8449 {Replacement("x", f(x))},
8450 {Replacement("y", g(r.x) + 2)},
8451 true
8452 );
8453 \endcode
8454 *
8455 * 'pattern' contains the list of equivalent binary/unary operators (+ identities)
8456 * for each Tuple element in the update definition. 'pattern' also contains
8457 * a boolean that indicates if the op is also commutative. 'xs' and 'ys'
8458 * contain the corresponding definition of each variable in the list of
8459 * binary operators.
8460 *
8461 * For unary operator, 'xs' is not set, i.e. it will be a pair of empty string
8462 * and undefined Expr: {"", Expr()}. 'pattern' will only contain the 'y' term in
8463 * this case. For example, min(g(r.x), 4), will be represented as:
8464 \code
8465 AssociativeOp assoc(
8466 AssociativePattern(y, 0, false),
8467 {Replacement("", Expr())},
8468 {Replacement("y", min(g(r.x), 4))},
8469 true
8470 );
8471 \endcode
8472 *
8473 * Self-assignment, f(x) = f(x), will be represented as:
8474 \code
8475 AssociativeOp assoc(
8476 AssociativePattern(x, 0, true),
8477 {Replacement("x", f(x))},
8478 {Replacement("", Expr())},
8479 true
8480 );
8481 \endcode
8482 * For both unary operator and self-assignment cases, the identity does not
8483 * matter. It can be anything.
8484 */
8485struct AssociativeOp {
8486 struct Replacement {
8487 /** Variable name that is used to replace the expr in 'op'. */
8488 std::string var;
8489 Expr expr;
8490
8491 Replacement() = default;
8492 Replacement(const std::string &var, Expr expr)
8493 : var(var), expr(std::move(expr)) {
8494 }
8495
8496 bool operator==(const Replacement &other) const {
8497 return (var == other.var) && equal(expr, other.expr);
8498 }
8499 bool operator!=(const Replacement &other) const {
8500 return !(*this == other);
8501 }
8502 };
8503
8504 /** List of pairs of binary associative op and its identity. */
8505 AssociativePattern pattern;
8506 std::vector<Replacement> xs;
8507 std::vector<Replacement> ys;
8508 bool is_associative = false;
8509
8510 AssociativeOp() = default;
8511 AssociativeOp(size_t size)
8512 : pattern(size), xs(size), ys(size) {
8513 }
8514 AssociativeOp(const AssociativePattern &p, const std::vector<Replacement> &xs,
8515 const std::vector<Replacement> &ys, bool is_associative)
8516 : pattern(p), xs(xs), ys(ys), is_associative(is_associative) {
8517 }
8518
8519 bool associative() const {
8520 return is_associative;
8521 }
8522 bool commutative() const {
8523 return pattern.is_commutative;
8524 }
8525 size_t size() const {
8526 return pattern.size();
8527 }
8528};
8529
8530/**
8531 * Given an update definition of a Func 'f', determine its equivalent
8532 * associative binary/unary operator if there is any. 'is_associative'
 * indicates if the operation was successfully proven as associative.
8534 */
8535AssociativeOp prove_associativity(
8536 const std::string &f, std::vector<Expr> args, std::vector<Expr> exprs);
8537
8538void associativity_test();
8539
8540} // namespace Internal
8541} // namespace Halide
8542
8543#endif
8544#ifndef HALIDE_ASYNC_PRODUCERS_H
8545#define HALIDE_ASYNC_PRODUCERS_H
8546
8547/** \file
8548 * Defines the lowering pass that injects task parallelism for producers that are scheduled as async.
8549 */
8550#include <map>
8551#include <string>
8552
8553
8554namespace Halide {
8555namespace Internal {
8556
8557class Function;
8558
8559Stmt fork_async_producers(Stmt s, const std::map<std::string, Function> &env);
8560
8561} // namespace Internal
8562} // namespace Halide
8563
8564#endif
8565#ifndef HALIDE_INTERNAL_AUTO_SCHEDULE_UTILS_H
8566#define HALIDE_INTERNAL_AUTO_SCHEDULE_UTILS_H
8567
8568/** \file
8569 *
8570 * Defines util functions that used by auto scheduler.
8571 */
8572
8573#include <limits>
8574#include <set>
8575
8576#ifndef HALIDE_DEFINITION_H
8577#define HALIDE_DEFINITION_H
8578
8579/** \file
8580 * Defines the internal representation of a halide function's definition and related classes
8581 */
8582
8583
8584#include <map>
8585
8586namespace Halide {
8587
8588namespace Internal {
8589struct DefinitionContents;
8590struct FunctionContents;
8591class ReductionDomain;
8592} // namespace Internal
8593
8594namespace Internal {
8595
8596class IRVisitor;
8597class IRMutator;
8598struct Specialization;
8599
8600/** A Function definition which can either represent a init or an update
8601 * definition. A function may have different definitions due to specialization,
8602 * which are stored in 'specializations' (Not possible from the front-end, but
8603 * some scheduling directives may potentially cause this divergence to occur).
8604 * Although init definition may have multiple values (RHS) per specialization, it
8605 * must have the same LHS (i.e. same pure dimension variables). The update
8606 * definition, on the other hand, may have different LHS/RHS per specialization.
8607 * Note that, while the Expr in LHS/RHS may be different across specializations,
8608 * they must have the same number of dimensions and the same pure dimensions.
8609 */
class Definition {

    // Reference-counted handle to the shared definition state.
    IntrusivePtr<DefinitionContents> contents;

public:
    /** Construct a Definition from an existing DefinitionContents pointer. Must be non-null */
    explicit Definition(const IntrusivePtr<DefinitionContents> &);

    /** Construct a Definition with the supplied args, values, and reduction domain. */
    Definition(const std::vector<Expr> &args, const std::vector<Expr> &values,
               const ReductionDomain &rdom, bool is_init);

    /** Construct an undefined Definition object. */
    Definition();

    /** Return a copy of this Definition. */
    Definition get_copy() const;

    /** Equality of identity: true iff both handles refer to the same
     * underlying DefinitionContents. */
    bool same_as(const Definition &other) const {
        return contents.same_as(other.contents);
    }

    /** Definition objects are nullable. Does this definition exist? */
    bool defined() const;

    /** Is this an init definition; otherwise it's an update definition */
    bool is_init() const;

    /** Pass an IRVisitor through to all Exprs referenced in the
     * definition. */
    void accept(IRVisitor *) const;

    /** Pass an IRMutator through to all Exprs referenced in the
     * definition. */
    void mutate(IRMutator *);

    /** Get the default (no-specialization) arguments (left-hand-side) of the definition */
    // @{
    const std::vector<Expr> &args() const;
    std::vector<Expr> &args();
    // @}

    /** Get the default (no-specialization) right-hand-side of the definition */
    // @{
    const std::vector<Expr> &values() const;
    std::vector<Expr> &values();
    // @}

    /** Get the predicate on the definition */
    // @{
    const Expr &predicate() const;
    Expr &predicate();
    // @}

    /** Split predicate into vector of ANDs. If there is no predicate (i.e. this
     * definition is always valid), this returns an empty vector. */
    std::vector<Expr> split_predicate() const;

    /** Get the default (no-specialization) stage-specific schedule associated
     * with this definition. */
    // @{
    const StageSchedule &schedule() const;
    StageSchedule &schedule();
    // @}

    /** You may create several specialized versions of a func with
     * different stage-specific schedules. They trigger when the condition is
     * true. See \ref Func::specialize */
    // @{
    const std::vector<Specialization> &specializations() const;
    std::vector<Specialization> &specializations();
    const Specialization &add_specialization(Expr condition);
    // @}

    /** Attempt to get the source file and line where this definition
     * was made using DWARF introspection. Returns an empty string if
     * no debug symbols were found or the debug symbols were not
     * understood. Works on OS X and Linux only. */
    std::string source_location() const;
};
8691
/** One specialized variant of a Definition: 'definition' applies when
 * 'condition' evaluates to true. See \ref Func::specialize. */
struct Specialization {
    Expr condition;
    Definition definition;
    std::string failure_message;  // If non-empty, this specialization always assert-fails with this message.
};
8697
8698} // namespace Internal
8699} // namespace Halide
8700
8701#endif
8702#ifndef HALIDE_IR_VISITOR_H
8703#define HALIDE_IR_VISITOR_H
8704
8705#include <set>
8706
8707#ifndef HALIDE_IR_H
8708#define HALIDE_IR_H
8709
8710/** \file
8711 * Subtypes for Halide expressions (\ref Halide::Expr) and statements (\ref Halide::Internal::Stmt)
8712 */
8713
8714#include <string>
8715#include <vector>
8716
8717#ifndef HALIDE_BUFFER_H
8718#define HALIDE_BUFFER_H
8719
8720#ifndef HALIDE_DEVICE_INTERFACE_H
8721#define HALIDE_DEVICE_INTERFACE_H
8722
8723/** \file
8724 * Methods for managing device allocations when jitting
8725 */
8726
8727
8728namespace Halide {
8729
8730/** Gets the appropriate halide_device_interface_t * for a
8731 * DeviceAPI. If error_site is non-null, e.g. the name of the routine
8732 * calling get_device_interface_for_device_api, a user_error is
8733 * reported if the requested device API is not enabled in or supported
8734 * by the target, Halide has been compiled without this device API, or
8735 * the device API is None or Host or a bad value. The error_site
8736 * argument is printed in the error message. If error_site is null,
8737 * this routine returns nullptr instead of calling user_error. */
8738const halide_device_interface_t *get_device_interface_for_device_api(DeviceAPI d,
8739 const Target &t = get_jit_target_from_environment(),
8740 const char *error_site = nullptr);
8741
8742/** Get the specific DeviceAPI that Halide would select when presented
8743 * with DeviceAPI::Default_GPU for a given target. If no suitable api
8744 * is enabled in the target, returns DeviceAPI::Host. */
8745DeviceAPI get_default_device_api_for_target(const Target &t);
8746
8747/** This attempts to sniff whether a given Target (and its implied DeviceAPI) is usable on
8748 * the current host. If it appears to be usable, return true; if not, return false.
8749 * Note that a return value of true does *not* guarantee that future usage of
8750 * that device will succeed; it is intended mainly as a simple diagnostic
8751 * to allow early-exit when a desired device is definitely not usable.
 * Also note that this call is *NOT* threadsafe, as it temporarily redirects various
8753 * global error-handling hooks in Halide. */
8754bool host_supports_target_device(const Target &t);
8755
8756namespace Internal {
8757/** Get an Expr which evaluates to the device interface for the given device api at runtime. */
8758Expr make_device_interface_call(DeviceAPI device_api, MemoryType memory_type = MemoryType::Auto);
8759} // namespace Internal
8760
8761} // namespace Halide
8762
8763#endif
8764/** \file
8765 * Defines a Buffer type that wraps from halide_buffer_t and adds
8766 * functionality, and methods for more conveniently iterating over the
8767 * samples in a halide_buffer_t outside of Halide code. */
8768
8769#ifndef HALIDE_RUNTIME_BUFFER_H
8770#define HALIDE_RUNTIME_BUFFER_H
8771
8772#include <algorithm>
8773#include <atomic>
8774#include <cassert>
8775#include <cstdint>
8776#include <cstring>
8777#include <limits>
8778#include <memory>
8779#include <vector>
8780
8781#if defined(__has_feature)
8782#if __has_feature(memory_sanitizer)
8783#include <sanitizer/msan_interface.h>
8784#endif
8785#endif
8786
8787
8788#ifdef _MSC_VER
8789#include <malloc.h>
8790#define HALIDE_ALLOCA _alloca
8791#else
8792#define HALIDE_ALLOCA __builtin_alloca
8793#endif
8794
8795// gcc 5.1 has a false positive warning on this code
8796#if __GNUC__ == 5 && __GNUC_MINOR__ == 1
8797#pragma GCC diagnostic ignored "-Warray-bounds"
8798#endif
8799
8800namespace Halide {
8801namespace Runtime {
8802
8803// Forward-declare our Buffer class
8804template<typename T, int D>
8805class Buffer;
8806
// A helper to check if a parameter pack is entirely implicitly
// int-convertible to use with std::enable_if
template<typename... Args>
struct AllInts : std::false_type {};

// Base case: an empty pack is trivially all-ints.
template<>
struct AllInts<> : std::true_type {};

// Recursive case: the head must be int-convertible and the tail must
// itself satisfy AllInts.
template<typename T, typename... Args>
struct AllInts<T, Args...> {
    static const bool value = std::is_convertible<T, int>::value && AllInts<Args...>::value;
};

// Floats and doubles are technically implicitly int-convertible, but
// doing so produces a warning we treat as an error, so just disallow
// it here.
template<typename... Args>
struct AllInts<float, Args...> : std::false_type {};

template<typename... Args>
struct AllInts<double, Args...> : std::false_type {};
8828
8829// A helper to detect if there are any zeros in a container
8830namespace Internal {
8831template<typename Container>
8832bool any_zero(const Container &c) {
8833 for (int i : c) {
8834 if (i == 0) {
8835 return true;
8836 }
8837 }
8838 return false;
8839}
8840} // namespace Internal
8841
/** A struct acting as a header for allocations owned by the Buffer
 * class itself. */
struct AllocationHeader {
    // Invoked on the header pointer itself to release the allocation
    // once ref_count reaches zero.
    void (*deallocate_fn)(void *);
    std::atomic<int> ref_count;

    // Note that ref_count always starts at 1
    AllocationHeader(void (*deallocate_fn)(void *))
        : deallocate_fn(deallocate_fn), ref_count(1) {
    }
};
8853
/** This indicates how to deallocate the device for a Halide::Runtime::Buffer. */
enum struct BufferDeviceOwnership : int {
    Allocated, ///< halide_device_free will be called when device ref count goes to zero
    WrappedNative, ///< halide_device_detach_native will be called when device ref count goes to zero
    Unmanaged, ///< No free routine will be called when device ref count goes to zero
    AllocatedDeviceAndHost, ///< Call device_and_host_free when DevRefCount goes to zero.
    Cropped, ///< Call halide_device_release_crop when DevRefCount goes to zero.
};
8862
/** A similar struct for managing device allocations. */
struct DeviceRefCount {
    // This is only ever constructed when there's something to manage,
    // so start at one.
    std::atomic<int> count{1};
    // How the device allocation is released when count hits zero.
    BufferDeviceOwnership ownership{BufferDeviceOwnership::Allocated};
};
8870
8871/** A templated Buffer class that wraps halide_buffer_t and adds
8872 * functionality. When using Halide from C++, this is the preferred
8873 * way to create input and output buffers. The overhead of using this
8874 * class relative to a naked halide_buffer_t is minimal - it uses another
8875 * ~16 bytes on the stack, and does no dynamic allocations when using
8876 * it to represent existing memory of a known maximum dimensionality.
8877 *
8878 * The template parameter T is the element type. For buffers where the
8879 * element type is unknown, or may vary, use void or const void.
8880 *
8881 * D is the maximum number of dimensions that can be represented using
8882 * space inside the class itself. Set it to the maximum dimensionality
8883 * you expect this buffer to be. If the actual dimensionality exceeds
8884 * this, heap storage is allocated to track the shape of the buffer. D
8885 * defaults to 4, which should cover nearly all usage.
8886 *
8887 * The class optionally allocates and owns memory for the image using
8888 * a shared pointer allocated with the provided allocator. If they are
8889 * null, malloc and free are used. Any device-side allocation is
8890 * considered as owned if and only if the host-side allocation is
8891 * owned. */
8892template<typename T = void, int D = 4>
8893class Buffer {
8894 /** The underlying halide_buffer_t */
8895 halide_buffer_t buf = {0};
8896
8897 /** Some in-class storage for shape of the dimensions. */
8898 halide_dimension_t shape[D];
8899
8900 /** The allocation owned by this Buffer. NULL if the Buffer does not
8901 * own the memory. */
8902 AllocationHeader *alloc = nullptr;
8903
8904 /** A reference count for the device allocation owned by this
8905 * buffer. */
8906 mutable DeviceRefCount *dev_ref_count = nullptr;
8907
8908 /** True if T is of type void or const void */
8909 static const bool T_is_void = std::is_same<typename std::remove_const<T>::type, void>::value;
8910
8911 /** A type function that adds a const qualifier if T is a const type. */
8912 template<typename T2>
8913 using add_const_if_T_is_const = typename std::conditional<std::is_const<T>::value, const T2, T2>::type;
8914
8915 /** T unless T is (const) void, in which case (const)
8916 * uint8_t. Useful for providing return types for operator() */
8917 using not_void_T = typename std::conditional<T_is_void,
8918 add_const_if_T_is_const<uint8_t>,
8919 T>::type;
8920
8921 /** T with constness removed. Useful for return type of copy(). */
8922 using not_const_T = typename std::remove_const<T>::type;
8923
8924 /** The type the elements are stored as. Equal to not_void_T
8925 * unless T is a pointer, in which case uint64_t. Halide stores
8926 * all pointer types as uint64s internally, even on 32-bit
8927 * systems. */
8928 using storage_T = typename std::conditional<std::is_pointer<T>::value, uint64_t, not_void_T>::type;
8929
public:
    /** True if the Halide type is not void (or const void). */
    static constexpr bool has_static_halide_type = !T_is_void;

    /** Get the Halide type of T. Callers should not use the result if
     * has_static_halide_type is false. The cv-qualifiers on T are
     * stripped before deducing the type. */
    static halide_type_t static_halide_type() {
        return halide_type_of<typename std::remove_cv<not_void_T>::type>();
    }

    /** Does this Buffer own the host memory it refers to? */
    bool owns_host_memory() const {
        return alloc != nullptr;
    }
8944
private:
    /** Increment the reference count of any owned allocation. Bumps the
     * host-side count if we own the host memory, and the device-side
     * count if there is a device allocation. */
    void incref() const {
        if (owns_host_memory()) {
            alloc->ref_count++;
        }
        if (buf.device) {
            if (!dev_ref_count) {
                // I seem to have a non-zero dev field but no
                // reference count for it. I must have been given a
                // device allocation by a Halide pipeline, and have
                // never been copied from since. Take sole ownership
                // of it.
                dev_ref_count = new DeviceRefCount;
            }
            dev_ref_count->count++;
        }
    }
8963
    // Note that this is called "cropped" but can also encompass a slice/embed
    // operation as well.
    struct DevRefCountCropped : DeviceRefCount {
        // Holds a copy of the parent buffer so its allocation stays alive
        // while this crop still references (a window into) it.
        Buffer<T, D> cropped_from;
        DevRefCountCropped(const Buffer<T, D> &cropped_from)
            : cropped_from(cropped_from) {
            // Cropped ownership means halide_device_release_crop is used
            // when this ref count hits zero (see decref).
            ownership = BufferDeviceOwnership::Cropped;
        }
    };
8973
    /** Setup the device ref count for a buffer to indicate it is a crop
     * (or slice, embed, etc) of cropped_from. Must only be called on a
     * buffer that does not yet have a device ref count. */
    void crop_from(const Buffer<T, D> &cropped_from) {
        assert(dev_ref_count == nullptr);
        dev_ref_count = new DevRefCountCropped(cropped_from);
    }
8979
    /** Decrement the reference count of any owned allocation and free host
     * and device memory if it hits zero. Sets alloc to nullptr. */
    void decref(bool device_only = false) {
        if (owns_host_memory() && !device_only) {
            int new_count = --(alloc->ref_count);
            if (new_count == 0) {
                // Capture the deallocation function before destroying the
                // header it lives in, then free the header's storage.
                void (*fn)(void *) = alloc->deallocate_fn;
                alloc->~AllocationHeader();
                fn(alloc);
            }
            buf.host = nullptr;
            alloc = nullptr;
            set_host_dirty(false);
        }
        // Now release our reference to any device allocation. If there is
        // no dev_ref_count, we are the sole owner (new_count stays 0) and
        // the device allocation is released below.
        int new_count = 0;
        if (dev_ref_count) {
            new_count = --(dev_ref_count->count);
        }
        if (new_count == 0) {
            if (buf.device) {
                assert(!(alloc && device_dirty()) &&
                       "Implicitly freeing a dirty device allocation while a host allocation still lives. "
                       "Call device_free explicitly if you want to drop dirty device-side data. "
                       "Call copy_to_host explicitly if you want the data copied to the host allocation "
                       "before the device allocation is freed.");
                // Release the device allocation according to its ownership
                // policy (see BufferDeviceOwnership).
                if (dev_ref_count && dev_ref_count->ownership == BufferDeviceOwnership::WrappedNative) {
                    buf.device_interface->detach_native(nullptr, &buf);
                } else if (dev_ref_count && dev_ref_count->ownership == BufferDeviceOwnership::AllocatedDeviceAndHost) {
                    buf.device_interface->device_and_host_free(nullptr, &buf);
                } else if (dev_ref_count && dev_ref_count->ownership == BufferDeviceOwnership::Cropped) {
                    buf.device_interface->device_release_crop(nullptr, &buf);
                } else if (dev_ref_count == nullptr || dev_ref_count->ownership == BufferDeviceOwnership::Allocated) {
                    buf.device_interface->device_free(nullptr, &buf);
                }
            }
            if (dev_ref_count) {
                // Cropped ref counts are the derived type; delete through
                // the correct static type since DeviceRefCount has no
                // virtual destructor.
                if (dev_ref_count->ownership == BufferDeviceOwnership::Cropped) {
                    delete (DevRefCountCropped *)dev_ref_count;
                } else {
                    delete dev_ref_count;
                }
            }
        }
        dev_ref_count = nullptr;
        buf.device = 0;
        buf.device_interface = nullptr;
    }
9027
    // Free the dimension array, but only if it was heap-allocated (i.e.
    // it is not pointing at the in-class 'shape' storage).
    void free_shape_storage() {
        if (buf.dim != shape) {
            delete[] buf.dim;
            buf.dim = nullptr;
        }
    }
9034
9035 void make_shape_storage(const int dimensions) {
9036 // This should usually be inlined, so if dimensions is statically known,
9037 // we can skip the call to new
9038 buf.dimensions = dimensions;
9039 buf.dim = (dimensions <= D) ? shape : new halide_dimension_t[dimensions];
9040 }
9041
    // Allocate shape storage and copy the dimensions from another buffer
    // descriptor into it.
    void copy_shape_from(const halide_buffer_t &other) {
        // All callers of this ensure that buf.dimensions == other.dimensions.
        make_shape_storage(other.dimensions);
        std::copy(other.dim, other.dim + other.dimensions, buf.dim);
    }
9047
    // Take over the shape of a buffer being moved from. If 'other' uses its
    // in-class storage we must copy (that storage dies with 'other');
    // otherwise we steal the heap-allocated array and null out other's
    // pointer so it won't be freed twice.
    // NOTE(review): other.buf.dimensions is left unchanged here — presumably
    // the caller resets other's buf afterwards; confirm at call sites.
    template<typename T2, int D2>
    void move_shape_from(Buffer<T2, D2> &&other) {
        if (other.shape == other.buf.dim) {
            copy_shape_from(other.buf);
        } else {
            buf.dim = other.buf.dim;
            other.buf.dim = nullptr;
        }
    }
9057
    /** Initialize the shape from a halide_buffer_t. Shallow-copies the
     * descriptor, then deep-copies the shape array so this Buffer does not
     * alias b's dim storage. If b has a device allocation, start tracking
     * it with the given ownership policy. */
    void initialize_from_buffer(const halide_buffer_t &b,
                                BufferDeviceOwnership ownership) {
        memcpy(&buf, &b, sizeof(halide_buffer_t));
        copy_shape_from(b);
        if (b.device) {
            dev_ref_count = new DeviceRefCount;
            dev_ref_count->ownership = ownership;
        }
    }
9068
9069 /** Initialize the shape from an array of ints */
9070 void initialize_shape(const int *sizes) {
9071 for (int i = 0; i < buf.dimensions; i++) {
9072 buf.dim[i].min = 0;
9073 buf.dim[i].extent = sizes[i];
9074 if (i == 0) {
9075 buf.dim[i].stride = 1;
9076 } else {
9077 buf.dim[i].stride = buf.dim[i - 1].stride * buf.dim[i - 1].extent;
9078 }
9079 }
9080 }
9081
    /** Initialize the shape from a vector of extents. The number of
     * dimensions must already have been set to match. */
    void initialize_shape(const std::vector<int> &sizes) {
        assert(buf.dimensions == (int)sizes.size());
        initialize_shape(sizes.data());
    }
9087
    /** Initialize the shape from the static shape of an array. Dimension
     * 'next' gets extent N; the recursion peels one array rank per step,
     * and the innermost rank (next == 0) is given stride 1. */
    template<typename Array, size_t N>
    void initialize_shape_from_array_shape(int next, Array (&vals)[N]) {
        buf.dim[next].min = 0;
        buf.dim[next].extent = (int)N;
        if (next == 0) {
            buf.dim[next].stride = 1;
        } else {
            // Fill in the inner dimensions first so their stride/extent are
            // available to compute this dimension's stride.
            initialize_shape_from_array_shape(next - 1, vals[0]);
            buf.dim[next].stride = buf.dim[next - 1].stride * buf.dim[next - 1].extent;
        }
    }

    /** Base case for the template recursion above: a non-array element. */
    template<typename T2>
    void initialize_shape_from_array_shape(int, const T2 &) {
    }
9105
    /** Get the dimensionality of a multi-dimensional C array: one per
     * array rank, counted recursively. */
    template<typename Array, size_t N>
    static int dimensionality_of_array(Array (&vals)[N]) {
        return dimensionality_of_array(vals[0]) + 1;
    }

    /** Base case: a non-array element type has dimensionality zero. */
    template<typename T2>
    static int dimensionality_of_array(const T2 &) {
        return 0;
    }
9116
    /** Get the underlying halide_type_t of an array's element type,
     * recursing through array ranks until a scalar element is reached. */
    template<typename Array, size_t N>
    static halide_type_t scalar_type_of_array(Array (&vals)[N]) {
        return scalar_type_of_array(vals[0]);
    }

    /** Base case: deduce the halide_type_t of the (cv-stripped) scalar. */
    template<typename T2>
    static halide_type_t scalar_type_of_array(const T2 &) {
        return halide_type_of<typename std::remove_cv<T2>::type>();
    }
9127
    /** Crop a single dimension without handling device allocation.
     * Advances the host pointer so the new min still addresses the same
     * underlying memory, then narrows the dimension's bounds. Asserts
     * the crop region lies within the existing bounds. */
    void crop_host(int d, int min, int extent) {
        assert(dim(d).min() <= min);
        assert(dim(d).max() >= min + extent - 1);
        ptrdiff_t shift = min - dim(d).min();
        if (buf.host != nullptr) {
            buf.host += (shift * dim(d).stride()) * type().bytes();
        }
        buf.dim[d].min = min;
        buf.dim[d].extent = extent;
    }
9139
9140 /** Crop as many dimensions as are in rect, without handling device allocation. */
9141 void crop_host(const std::vector<std::pair<int, int>> &rect) {
9142 assert(rect.size() <= static_cast<decltype(rect.size())>(std::numeric_limits<int>::max()));
9143 int limit = (int)rect.size();
9144 assert(limit <= dimensions());
9145 for (int i = 0; i < limit; i++) {
9146 crop_host(i, rect[i].first, rect[i].second);
9147 }
9148 }
9149
    /** After a host-side crop, ask the device interface to produce a
     * matching device-side crop in result_host_cropped, and record which
     * buffer the crop was taken from for device ref-counting. */
    void complete_device_crop(Buffer<T, D> &result_host_cropped) const {
        assert(buf.device_interface != nullptr);
        if (buf.device_interface->device_crop(nullptr, &this->buf, &result_host_cropped.buf) == 0) {
            const Buffer<T, D> *cropped_from = this;
            // TODO: Figure out what to do if dev_ref_count is nullptr. Should incref logic run here?
            // is it possible to get to this point without incref having run at least once since
            // the device field was set? (I.e. in the internal logic of crop. incref might have been
            // called.)
            // A crop of a crop must reference the original source buffer,
            // not the intermediate crop.
            if (dev_ref_count != nullptr && dev_ref_count->ownership == BufferDeviceOwnership::Cropped) {
                cropped_from = &((DevRefCountCropped *)dev_ref_count)->cropped_from;
            }
            result_host_cropped.crop_from(*cropped_from);
        }
    }
9164
    /** Slice a single dimension without handling device allocation.
     * Drops dimension d entirely: advances the host pointer to the
     * plane at coordinate pos, then shifts the higher dimensions down
     * by one and clears the now-unused trailing dimension. */
    void slice_host(int d, int pos) {
        assert(d >= 0 && d < dimensions());
        assert(pos >= dim(d).min() && pos <= dim(d).max());
        buf.dimensions--;
        ptrdiff_t shift = pos - buf.dim[d].min;
        if (buf.host != nullptr) {
            buf.host += (shift * buf.dim[d].stride) * type().bytes();
        }
        for (int i = d; i < buf.dimensions; i++) {
            buf.dim[i] = buf.dim[i + 1];
        }
        buf.dim[buf.dimensions] = {0, 0, 0};
    }

    /** After a host-side slice, ask the device interface to produce a
     * matching device-side slice, recording the source buffer for
     * device ref-counting (see complete_device_crop). */
    void complete_device_slice(Buffer<T, D> &result_host_sliced, int d, int pos) const {
        assert(buf.device_interface != nullptr);
        if (buf.device_interface->device_slice(nullptr, &this->buf, d, pos, &result_host_sliced.buf) == 0) {
            const Buffer<T, D> *sliced_from = this;
            // TODO: Figure out what to do if dev_ref_count is nullptr. Should incref logic run here?
            // is it possible to get to this point without incref having run at least once since
            // the device field was set? (I.e. in the internal logic of slice. incref might have been
            // called.)
            if (dev_ref_count != nullptr && dev_ref_count->ownership == BufferDeviceOwnership::Cropped) {
                sliced_from = &((DevRefCountCropped *)dev_ref_count)->cropped_from;
            }
            // crop_from() is correct here, despite the fact that we are slicing.
            result_host_sliced.crop_from(*sliced_from);
        }
    }
9195
9196public:
9197 typedef T ElemType;
9198
    /** Read-only access to the shape */
    class Dimension {
        // Non-owning reference into the Buffer's halide_dimension_t
        // array; a Dimension must not outlive the Buffer it came from.
        const halide_dimension_t &d;

    public:
        /** The lowest coordinate in this dimension */
        HALIDE_ALWAYS_INLINE int min() const {
            return d.min;
        }

        /** The number of elements in memory you have to step over to
         * increment this coordinate by one. */
        HALIDE_ALWAYS_INLINE int stride() const {
            return d.stride;
        }

        /** The extent of the image along this dimension */
        HALIDE_ALWAYS_INLINE int extent() const {
            return d.extent;
        }

        /** The highest coordinate in this dimension (inclusive) */
        HALIDE_ALWAYS_INLINE int max() const {
            return min() + extent() - 1;
        }

        /** An iterator class, so that you can iterate over
         * coordinates in a dimensions using a range-based for loop. */
        struct iterator {
            int val;
            int operator*() const {
                return val;
            }
            bool operator!=(const iterator &other) const {
                return val != other.val;
            }
            iterator &operator++() {
                val++;
                return *this;
            }
        };

        /** An iterator that points to the min coordinate */
        HALIDE_ALWAYS_INLINE iterator begin() const {
            return {min()};
        }

        /** An iterator that points to one past the max coordinate */
        HALIDE_ALWAYS_INLINE iterator end() const {
            return {min() + extent()};
        }

        // Deliberately implicit, so a halide_dimension_t converts to a
        // Dimension where one is expected.
        Dimension(const halide_dimension_t &dim)
            : d(dim) {
        }
    };

    /** Access the shape of the buffer */
    HALIDE_ALWAYS_INLINE Dimension dim(int i) const {
        assert(i >= 0 && i < this->dimensions());
        return Dimension(buf.dim[i]);
    }
9261
    /** Access to the mins, strides, extents. Will be deprecated. Do not
     * use; prefer dim(i).min() / dim(i).extent() / dim(i).stride(). */
    // @{
    int min(int i) const {
        return dim(i).min();
    }
    int extent(int i) const {
        return dim(i).extent();
    }
    int stride(int i) const {
        return dim(i).stride();
    }
    // @}
9274
    /** The total number of elements this buffer represents. Equal to
     * the product of the extents. Delegates to the raw
     * halide_buffer_t. */
    size_t number_of_elements() const {
        return buf.number_of_elements();
    }

    /** Get the dimensionality of the buffer. */
    int dimensions() const {
        return buf.dimensions;
    }

    /** Get the runtime type of the elements. */
    halide_type_t type() const {
        return buf.type;
    }
9290
    /** A pointer to the element with the lowest address. If all
     * strides are positive, equal to the host pointer. */
    T *begin() const {
        assert(buf.host != nullptr);  // Cannot call begin() on an unallocated Buffer.
        return (T *)buf.begin();
    }

    /** A pointer to one beyond the element with the highest address. */
    T *end() const {
        assert(buf.host != nullptr);  // Cannot call end() on an unallocated Buffer.
        return (T *)buf.end();
    }

    /** The total number of bytes spanned by the data in memory
     * (end() - begin(), in bytes). */
    size_t size_in_bytes() const {
        return buf.size_in_bytes();
    }
9308
    /** Reset the Buffer to be equivalent to a default-constructed Buffer
     * of the same static type (if any); Buffer<void> will have its runtime
     * type reset to uint8. Drops any reference to owned memory. */
    void reset() {
        *this = Buffer();
    }

    /** Default constructor: a zero-dimensional, unallocated Buffer. */
    Buffer()
        : shape() {
        buf.type = static_halide_type();
        make_shape_storage(0);
    }
9321
    /** Make a Buffer from a halide_buffer_t. Asserts at runtime that the
     * buffer's type matches T (unless T is void). The ownership argument
     * states whether this Buffer should manage any device allocation. */
    explicit Buffer(const halide_buffer_t &buf,
                    BufferDeviceOwnership ownership = BufferDeviceOwnership::Unmanaged) {
        assert(T_is_void || buf.type == static_halide_type());
        initialize_from_buffer(buf, ownership);
    }
9328
9329 /** Give Buffers access to the members of Buffers of different dimensionalities and types. */
9330 template<typename T2, int D2>
9331 friend class Buffer;
9332
9333private:
    /** Compile-time half of the convertibility check: fails to compile
     * if a Buffer<T2, D2> could never convert to a Buffer<T, D>
     * (dropping const, or mismatched non-void element types). */
    template<typename T2, int D2>
    static void static_assert_can_convert_from() {
        static_assert((!std::is_const<T2>::value || std::is_const<T>::value),
                      "Can't convert from a Buffer<const T> to a Buffer<T>");
        static_assert(std::is_same<typename std::remove_const<T>::type,
                                   typename std::remove_const<T2>::type>::value ||
                          T_is_void || Buffer<T2, D2>::T_is_void,
                      "type mismatch constructing Buffer");
    }
9343
public:
    /** Determine if a Buffer<T, D> can be constructed from some other Buffer type.
     * If this can be determined at compile time, fail with a static assert; otherwise
     * return a boolean based on runtime typing. */
    template<typename T2, int D2>
    static bool can_convert_from(const Buffer<T2, D2> &other) {
        static_assert_can_convert_from<T2, D2>();
        if (Buffer<T2, D2>::T_is_void && !T_is_void) {
            // Converting an untyped buffer to a typed one requires the
            // runtime types to agree.
            return other.type() == static_halide_type();
        }
        return true;
    }

    /** Fail an assertion at runtime or compile-time if an Buffer<T, D>
     * cannot be constructed from some other Buffer type. */
    template<typename T2, int D2>
    static void assert_can_convert_from(const Buffer<T2, D2> &other) {
        // Explicitly call static_assert_can_convert_from() here so
        // that we always get compile-time checking, even if compiling with
        // assertions disabled.
        static_assert_can_convert_from<T2, D2>();
        assert(can_convert_from(other));
    }
9367
    /** Copy constructor. Does not copy underlying data; the two Buffers
     * share the host allocation (ref-counted) and device handle. */
    Buffer(const Buffer<T, D> &other)
        : buf(other.buf),
          alloc(other.alloc) {
        other.incref();
        dev_ref_count = other.dev_ref_count;
        copy_shape_from(other.buf);
    }

    /** Construct a Buffer from a Buffer of different dimensionality
     * and type. Asserts that the type matches (at runtime, if one of
     * the types is void). Note that this constructor is
     * implicit. This, for example, lets you pass things like
     * Buffer<T> or Buffer<const void> to functions expected
     * Buffer<const T>. */
    template<typename T2, int D2>
    Buffer(const Buffer<T2, D2> &other)
        : buf(other.buf),
          alloc(other.alloc) {
        assert_can_convert_from(other);
        other.incref();
        dev_ref_count = other.dev_ref_count;
        copy_shape_from(other.buf);
    }
9392
    /** Move constructor. Steals the other Buffer's allocation and
     * device refcount, leaving it equivalent to a default-constructed
     * Buffer. */
    Buffer(Buffer<T, D> &&other) noexcept
        : buf(other.buf),
          alloc(other.alloc),
          dev_ref_count(other.dev_ref_count) {
        other.dev_ref_count = nullptr;
        other.alloc = nullptr;
        move_shape_from(std::forward<Buffer<T, D>>(other));
        other.buf = halide_buffer_t();
    }

    /** Move-construct a Buffer from a Buffer of different
     * dimensionality and type. Asserts that the types match (at
     * runtime if one of the types is void). */
    template<typename T2, int D2>
    Buffer(Buffer<T2, D2> &&other)
        : buf(other.buf),
          alloc(other.alloc),
          dev_ref_count(other.dev_ref_count) {
        assert_can_convert_from(other);
        other.dev_ref_count = nullptr;
        other.alloc = nullptr;
        move_shape_from(std::forward<Buffer<T2, D2>>(other));
        other.buf = halide_buffer_t();
    }
9418
    /** Assign from another Buffer of possibly-different
     * dimensionality and type. Asserts that the types match (at
     * runtime if one of the types is void). */
    template<typename T2, int D2>
    Buffer<T, D> &operator=(const Buffer<T2, D2> &other) {
        if ((const void *)this == (const void *)&other) {
            return *this;
        }
        assert_can_convert_from(other);
        // incref before decref so a shared allocation is never
        // transiently released.
        other.incref();
        decref();
        dev_ref_count = other.dev_ref_count;
        alloc = other.alloc;
        free_shape_storage();
        buf = other.buf;
        copy_shape_from(other.buf);
        return *this;
    }

    /** Standard assignment operator */
    Buffer<T, D> &operator=(const Buffer<T, D> &other) {
        // The cast to void* here is just to satisfy clang-tidy
        if ((const void *)this == (const void *)&other) {
            return *this;
        }
        // incref before decref so a shared allocation is never
        // transiently released.
        other.incref();
        decref();
        dev_ref_count = other.dev_ref_count;
        alloc = other.alloc;
        free_shape_storage();
        buf = other.buf;
        copy_shape_from(other.buf);
        return *this;
    }
9453
    /** Move from another Buffer of possibly-different
     * dimensionality and type. Asserts that the types match (at
     * runtime if one of the types is void).
     *
     * NOTE(review): unlike the copy assignments, there is no self-move
     * guard here; decref() could release state that is then read back
     * from `other` if this == &other. Presumably callers never
     * self-move — confirm before relying on it. */
    template<typename T2, int D2>
    Buffer<T, D> &operator=(Buffer<T2, D2> &&other) {
        assert_can_convert_from(other);
        decref();
        alloc = other.alloc;
        other.alloc = nullptr;
        dev_ref_count = other.dev_ref_count;
        other.dev_ref_count = nullptr;
        free_shape_storage();
        buf = other.buf;
        move_shape_from(std::forward<Buffer<T2, D2>>(other));
        other.buf = halide_buffer_t();
        return *this;
    }

    /** Standard move-assignment operator. (See the self-move note on
     * the templated overload above.) */
    Buffer<T, D> &operator=(Buffer<T, D> &&other) noexcept {
        decref();
        alloc = other.alloc;
        other.alloc = nullptr;
        dev_ref_count = other.dev_ref_count;
        other.dev_ref_count = nullptr;
        free_shape_storage();
        buf = other.buf;
        move_shape_from(std::forward<Buffer<T, D>>(other));
        other.buf = halide_buffer_t();
        return *this;
    }
9485
    /** Check the product of the extents fits in memory. Computes the
     * total byte count (letting it wrap on overflow), then divides the
     * extents back out; only if no overflow occurred does this recover
     * the element size exactly. */
    void check_overflow() {
        size_t size = type().bytes();
        for (int i = 0; i < dimensions(); i++) {
            size *= dim(i).extent();
        }
        // We allow 2^31 or 2^63 bytes, so drop the top bit.
        size = (size << 1) >> 1;
        for (int i = 0; i < dimensions(); i++) {
            size /= dim(i).extent();
        }
        assert(size == (size_t)type().bytes() && "Error: Overflow computing total size of buffer.");
    }
9499
    /** Allocate memory for this Buffer. Drops the reference to any
     * owned memory. Custom allocate/deallocate functions may be
     * supplied; they default to malloc/free. */
    void allocate(void *(*allocate_fn)(size_t) = nullptr,
                  void (*deallocate_fn)(void *) = nullptr) {
        if (!allocate_fn) {
            allocate_fn = malloc;
        }
        if (!deallocate_fn) {
            deallocate_fn = free;
        }

        // Drop any existing allocation
        deallocate();

        // Conservatively align images to 128 bytes. This is enough
        // alignment for all the platforms we might use.
        size_t size = size_in_bytes();
        const size_t alignment = 128;
        size = (size + alignment - 1) & ~(alignment - 1);
        // One allocation holds the ref-counted AllocationHeader followed
        // by the pixel data, with slack so the data can be aligned up.
        void *alloc_storage = allocate_fn(size + sizeof(AllocationHeader) + alignment - 1);
        alloc = new (alloc_storage) AllocationHeader(deallocate_fn);
        uint8_t *unaligned_ptr = ((uint8_t *)alloc) + sizeof(AllocationHeader);
        buf.host = (uint8_t *)((uintptr_t)(unaligned_ptr + alignment - 1) & ~(alignment - 1));
    }
9524
    /** Drop reference to any owned host or device memory, possibly
     * freeing it, if this buffer held the last reference to
     * it. Retains the shape of the buffer. Does nothing if this
     * buffer did not allocate its own memory. */
    void deallocate() {
        decref();
    }

    /** Drop reference to any owned device memory, possibly freeing it
     * if this buffer held the last reference to it. Asserts that
     * device_dirty is false. */
    void device_deallocate() {
        // decref(true) releases only the device side.
        decref(true);
    }
9539
    /** Allocate a new image of the given size with a runtime
     * type. Only used when you do know what size you want but you
     * don't know statically what type the elements are. Pass zeroes
     * to make a buffer suitable for bounds query calls. */
    template<typename... Args,
             typename = typename std::enable_if<AllInts<Args...>::value>::type>
    Buffer(halide_type_t t, int first, Args... rest) {
        if (!T_is_void) {
            assert(static_halide_type() == t);
        }
        int extents[] = {first, (int)rest...};
        buf.type = t;
        constexpr int buf_dimensions = 1 + (int)(sizeof...(rest));
        make_shape_storage(buf_dimensions);
        initialize_shape(extents);
        // Any zero extent means this is a bounds-query buffer; skip
        // allocation entirely.
        if (!Internal::any_zero(extents)) {
            check_overflow();
            allocate();
        }
    }
9560
    /** Allocate a new image of the given size. Pass zeroes to make a
     * buffer suitable for bounds query calls. */
    // @{

    // The overload with one argument is 'explicit', so that
    // (say) int is not implicitly convertible to Buffer<int>
    explicit Buffer(int first) {
        static_assert(!T_is_void,
                      "To construct an Buffer<void>, pass a halide_type_t as the first argument to the constructor");
        int extents[] = {first};
        buf.type = static_halide_type();
        constexpr int buf_dimensions = 1;
        make_shape_storage(buf_dimensions);
        initialize_shape(extents);
        // A zero extent means bounds-query mode; skip allocation.
        if (first != 0) {
            check_overflow();
            allocate();
        }
    }

    template<typename... Args,
             typename = typename std::enable_if<AllInts<Args...>::value>::type>
    Buffer(int first, int second, Args... rest) {
        static_assert(!T_is_void,
                      "To construct an Buffer<void>, pass a halide_type_t as the first argument to the constructor");
        int extents[] = {first, second, (int)rest...};
        buf.type = static_halide_type();
        constexpr int buf_dimensions = 2 + (int)(sizeof...(rest));
        make_shape_storage(buf_dimensions);
        initialize_shape(extents);
        // Any zero extent means bounds-query mode; skip allocation.
        if (!Internal::any_zero(extents)) {
            check_overflow();
            allocate();
        }
    }
    // @}
9597
    /** Allocate a new image of unknown type using a vector of ints as the size. */
    Buffer(halide_type_t t, const std::vector<int> &sizes) {
        if (!T_is_void) {
            assert(static_halide_type() == t);
        }
        buf.type = t;
        make_shape_storage((int)sizes.size());
        initialize_shape(sizes);
        // Any zero extent means bounds-query mode; skip allocation.
        if (!Internal::any_zero(sizes)) {
            check_overflow();
            allocate();
        }
    }

    /** Allocate a new image of known type using a vector of ints as the size. */
    explicit Buffer(const std::vector<int> &sizes)
        : Buffer(static_halide_type(), sizes) {
    }
9616
9617private:
9618 // Create a copy of the sizes vector, ordered as specified by order.
9619 static std::vector<int> make_ordered_sizes(const std::vector<int> &sizes, const std::vector<int> &order) {
9620 assert(order.size() == sizes.size());
9621 std::vector<int> ordered_sizes(sizes.size());
9622 for (size_t i = 0; i < sizes.size(); ++i) {
9623 ordered_sizes[i] = sizes.at(order[i]);
9624 }
9625 return ordered_sizes;
9626 }
9627
public:
    /** Allocate a new image of unknown type using a vector of ints as the size and
     * a vector of indices indicating the storage order for each dimension. The
     * length of the sizes vector and the storage-order vector must match. For instance,
     * to allocate an interleaved RGB buffer, you would pass {2, 0, 1} for storage_order. */
    Buffer(halide_type_t t, const std::vector<int> &sizes, const std::vector<int> &storage_order)
        : Buffer(t, make_ordered_sizes(sizes, storage_order)) {
        // Allocate densely in storage order, then transpose the
        // dimensions back into logical order.
        transpose(storage_order);
    }

    /** As above, but with the element type known statically. */
    Buffer(const std::vector<int> &sizes, const std::vector<int> &storage_order)
        : Buffer(static_halide_type(), sizes, storage_order) {
    }
9641
    /** Make an Buffer that refers to a statically sized array. Does not
     * take ownership of the data, and does not set the host_dirty flag.
     * Dimensionality and type are deduced from the array's static
     * shape and element type. */
    template<typename Array, size_t N>
    explicit Buffer(Array (&vals)[N]) {
        const int buf_dimensions = dimensionality_of_array(vals);
        buf.type = scalar_type_of_array(vals);
        buf.host = (uint8_t *)vals;
        make_shape_storage(buf_dimensions);
        initialize_shape_from_array_shape(buf.dimensions - 1, vals);
    }
9652
    /** Initialize an Buffer of runtime type from a pointer and some
     * sizes. Assumes dense row-major packing and a min coordinate of
     * zero. Does not take ownership of the data and does not set the
     * host_dirty flag. */
    template<typename... Args,
             typename = typename std::enable_if<AllInts<Args...>::value>::type>
    explicit Buffer(halide_type_t t, add_const_if_T_is_const<void> *data, int first, Args &&...rest) {
        if (!T_is_void) {
            assert(static_halide_type() == t);
        }
        int extents[] = {first, (int)rest...};
        buf.type = t;
        constexpr int buf_dimensions = 1 + (int)(sizeof...(rest));
        buf.host = (uint8_t *)const_cast<void *>(data);
        make_shape_storage(buf_dimensions);
        initialize_shape(extents);
    }

    /** Initialize an Buffer from a pointer and some sizes. Assumes
     * dense row-major packing and a min coordinate of zero. Does not
     * take ownership of the data and does not set the host_dirty flag. */
    template<typename... Args,
             typename = typename std::enable_if<AllInts<Args...>::value>::type>
    explicit Buffer(T *data, int first, Args &&...rest) {
        int extents[] = {first, (int)rest...};
        buf.type = static_halide_type();
        constexpr int buf_dimensions = 1 + (int)(sizeof...(rest));
        buf.host = (uint8_t *)const_cast<typename std::remove_const<T>::type *>(data);
        make_shape_storage(buf_dimensions);
        initialize_shape(extents);
    }
9684
    /** Initialize an Buffer from a pointer and a vector of
     * sizes. Assumes dense row-major packing and a min coordinate of
     * zero. Does not take ownership of the data and does not set the
     * host_dirty flag. */
    explicit Buffer(T *data, const std::vector<int> &sizes) {
        buf.type = static_halide_type();
        buf.host = (uint8_t *)const_cast<typename std::remove_const<T>::type *>(data);
        make_shape_storage((int)sizes.size());
        initialize_shape(sizes);
    }

    /** Initialize an Buffer of runtime type from a pointer and a
     * vector of sizes. Assumes dense row-major packing and a min
     * coordinate of zero. Does not take ownership of the data and
     * does not set the host_dirty flag. */
    explicit Buffer(halide_type_t t, add_const_if_T_is_const<void> *data, const std::vector<int> &sizes) {
        if (!T_is_void) {
            assert(static_halide_type() == t);
        }
        buf.type = t;
        buf.host = (uint8_t *)const_cast<void *>(data);
        make_shape_storage((int)sizes.size());
        initialize_shape(sizes);
    }
9709
    /** Initialize an Buffer from a pointer to the min coordinate and
     * an array describing the shape. Does not take ownership of the
     * data, and does not set the host_dirty flag. */
    explicit Buffer(halide_type_t t, add_const_if_T_is_const<void> *data, int d, const halide_dimension_t *shape) {
        if (!T_is_void) {
            assert(static_halide_type() == t);
        }
        buf.type = t;
        buf.host = (uint8_t *)const_cast<void *>(data);
        make_shape_storage(d);
        for (int i = 0; i < d; i++) {
            buf.dim[i] = shape[i];
        }
    }

    /** Initialize a Buffer from a pointer to the min coordinate and
     * a vector describing the shape. Does not take ownership of the
     * data, and does not set the host_dirty flag. */
    explicit inline Buffer(halide_type_t t, add_const_if_T_is_const<void> *data,
                           const std::vector<halide_dimension_t> &shape)
        : Buffer(t, data, (int)shape.size(), shape.data()) {
    }

    /** Initialize an Buffer from a pointer to the min coordinate and
     * an array describing the shape. Does not take ownership of the
     * data and does not set the host_dirty flag. */
    explicit Buffer(T *data, int d, const halide_dimension_t *shape) {
        buf.type = static_halide_type();
        buf.host = (uint8_t *)const_cast<typename std::remove_const<T>::type *>(data);
        make_shape_storage(d);
        for (int i = 0; i < d; i++) {
            buf.dim[i] = shape[i];
        }
    }

    /** Initialize a Buffer from a pointer to the min coordinate and
     * a vector describing the shape. Does not take ownership of the
     * data, and does not set the host_dirty flag. */
    explicit inline Buffer(T *data, const std::vector<halide_dimension_t> &shape)
        : Buffer(data, (int)shape.size(), shape.data()) {
    }
9751
    /** Destructor. Will release any underlying owned allocation if
     * this is the last reference to it. Will assert fail if there are
     * weak references to this Buffer outstanding. */
    ~Buffer() {
        free_shape_storage();
        decref();
    }

    /** Get a pointer to the raw halide_buffer_t this wraps. */
    // @{
    halide_buffer_t *raw_buffer() {
        return &buf;
    }

    const halide_buffer_t *raw_buffer() const {
        return &buf;
    }
    // @}

    /** Provide a cast operator to halide_buffer_t *, so that
     * instances can be passed directly to Halide filters. */
    operator halide_buffer_t *() {
        return &buf;
    }
9776
    /** Return a typed reference to this Buffer. Useful for converting
     * a reference to a Buffer<void> to a reference to, for example, a
     * Buffer<const uint8_t>, or converting a Buffer<T>& to Buffer<const T>&.
     * Does a runtime assert if the source buffer type is void. */
    template<typename T2>
    HALIDE_ALWAYS_INLINE Buffer<T2, D> &as() & {
        Buffer<T2, D>::assert_can_convert_from(*this);
        return *((Buffer<T2, D> *)this);
    }

    /** Return a const typed reference to this Buffer. Useful for
     * converting a const reference to one Buffer type to a const
     * reference to another Buffer type. Does a runtime assert if the
     * source buffer type is void. */
    template<typename T2>
    HALIDE_ALWAYS_INLINE const Buffer<T2, D> &as() const & {
        Buffer<T2, D>::assert_can_convert_from(*this);
        return *((const Buffer<T2, D> *)this);
    }

    /** Returns this rval Buffer with a different type attached. Does
     * a dynamic type check if the source type is void. */
    template<typename T2>
    HALIDE_ALWAYS_INLINE Buffer<T2, D> as() && {
        Buffer<T2, D>::assert_can_convert_from(*this);
        return *((Buffer<T2, D> *)this);
    }

    /** as_const() is syntactic sugar for .as<const T>(), to avoid the need
     * to recapitulate the type argument. */
    // @{
    HALIDE_ALWAYS_INLINE
    Buffer<typename std::add_const<T>::type, D> &as_const() & {
        // Note that we can skip the assert_can_convert_from(), since T -> const T
        // conversion is always legal.
        return *((Buffer<typename std::add_const<T>::type, D> *)this);
    }

    HALIDE_ALWAYS_INLINE
    const Buffer<typename std::add_const<T>::type, D> &as_const() const & {
        return *((const Buffer<typename std::add_const<T>::type, D> *)this);
    }

    HALIDE_ALWAYS_INLINE
    Buffer<typename std::add_const<T>::type, D> as_const() && {
        return *((Buffer<typename std::add_const<T>::type, D> *)this);
    }
    // @}
9825
    /** Conventional names for the first three dimensions. A missing
     * dimension reports size 1. */
    // @{
    int width() const {
        return (dimensions() > 0) ? dim(0).extent() : 1;
    }
    int height() const {
        return (dimensions() > 1) ? dim(1).extent() : 1;
    }
    int channels() const {
        return (dimensions() > 2) ? dim(2).extent() : 1;
    }
    // @}

    /** Conventional names for the min and max value of each dimension */
    // @{
    int left() const {
        return dim(0).min();
    }

    int right() const {
        return dim(0).max();
    }

    int top() const {
        return dim(1).min();
    }

    int bottom() const {
        return dim(1).max();
    }
    // @}
9857
    /** Make a new image which is a deep copy of this image. Use crop
     * or slice followed by copy to make a copy of only a portion of
     * the image. The new image uses the same memory layout as the
     * original, with holes compacted away. Note that the returned
     * Buffer is always of a non-const type T (ie:
     *
     * Buffer<const T>.copy() -> Buffer<T> rather than Buffer<const T>
     *
     * which is always safe, since we are making a deep copy. (The caller
     * can easily cast it back to Buffer<const T> if desired, which is
     * always safe and free.)
     */
    Buffer<not_const_T, D> copy(void *(*allocate_fn)(size_t) = nullptr,
                                void (*deallocate_fn)(void *) = nullptr) const {
        Buffer<not_const_T, D> dst = Buffer<not_const_T, D>::make_with_shape_of(*this, allocate_fn, deallocate_fn);
        dst.copy_from(*this);
        return dst;
    }

    /** Like copy(), but the copy is created in interleaved memory layout
     * (vs. keeping the same memory layout as the original). Requires that 'this'
     * has exactly 3 dimensions.
     */
    Buffer<not_const_T, D> copy_to_interleaved(void *(*allocate_fn)(size_t) = nullptr,
                                               void (*deallocate_fn)(void *) = nullptr) const {
        assert(dimensions() == 3);
        // Build an unallocated interleaved buffer, match the source's
        // mins, then allocate and fill it.
        Buffer<not_const_T, D> dst = Buffer<not_const_T, D>::make_interleaved(nullptr, width(), height(), channels());
        dst.set_min(min(0), min(1), min(2));
        dst.allocate(allocate_fn, deallocate_fn);
        dst.copy_from(*this);
        return dst;
    }
9890
9891 /** Like copy(), but the copy is created in planar memory layout
9892 * (vs. keeping the same memory layout as the original).
9893 */
9894 Buffer<not_const_T, D> copy_to_planar(void *(*allocate_fn)(size_t) = nullptr,
9895 void (*deallocate_fn)(void *) = nullptr) const {
9896 std::vector<int> mins, extents;
9897 const int dims = dimensions();
9898 mins.reserve(dims);
9899 extents.reserve(dims);
9900 for (int d = 0; d < dims; ++d) {
9901 mins.push_back(dim(d).min());
9902 extents.push_back(dim(d).extent());
9903 }
9904 Buffer<not_const_T, D> dst = Buffer<not_const_T, D>(nullptr, extents);
9905 dst.set_min(mins);
9906 dst.allocate(allocate_fn, deallocate_fn);
9907 dst.copy_from(*this);
9908 return dst;
9909 }
9910
    /** Make a copy of the Buffer which shares the underlying host and/or device
     * allocations as the existing Buffer. This is purely syntactic sugar for
     * cases where you have a const reference to a Buffer but need a temporary
     * non-const copy (e.g. to make a call into AOT-generated Halide code), and want a terse
     * inline way to create a temporary. \code
     * void call_my_func(const Buffer<const uint8_t>& input) {
     *     my_func(input.alias(), output);
     * }\endcode
     */
    inline Buffer<T, D> alias() const {
        // The copy constructor shares the allocation; no data is copied.
        return *this;
    }
9923
    /** Fill a Buffer with the values at the same coordinates in
     * another Buffer. Restricts itself to coordinates contained
     * within the intersection of the two buffers. If the two Buffers
     * are not in the same coordinate system, you will need to
     * translate the argument Buffer first. E.g. if you're blitting a
     * sprite onto a framebuffer, you'll want to translate the sprite
     * to the correct location first like so: \code
     * framebuffer.copy_from(sprite.translated({x, y})); \endcode
     */
    template<typename T2, int D2>
    void copy_from(Buffer<T2, D2> src) {
        static_assert(!std::is_const<T>::value, "Cannot call copy_from() on a Buffer<const T>");
        assert(!device_dirty() && "Cannot call Halide::Runtime::Buffer::copy_from on a device dirty destination.");
        assert(!src.device_dirty() && "Cannot call Halide::Runtime::Buffer::copy_from on a device dirty source.");

        // Take shallow copies so we can crop them to the overlapping
        // region without disturbing the originals.
        Buffer<T, D> dst(*this);

        assert(src.dimensions() == dst.dimensions());

        // Trim the copy to the region in common
        for (int i = 0; i < dimensions(); i++) {
            int min_coord = std::max(dst.dim(i).min(), src.dim(i).min());
            int max_coord = std::min(dst.dim(i).max(), src.dim(i).max());
            if (max_coord < min_coord) {
                // The buffers do not overlap.
                return;
            }
            dst.crop(i, min_coord, max_coord - min_coord + 1);
            src.crop(i, min_coord, max_coord - min_coord + 1);
        }

        // If T is void, we need to do runtime dispatch to an
        // appropriately-typed lambda. We're copying, so we only care
        // about the element size. (If not, this should optimize away
        // into a static dispatch to the right-sized copy.)
        if (T_is_void ? (type().bytes() == 1) : (sizeof(not_void_T) == 1)) {
            using MemType = uint8_t;
            auto &typed_dst = (Buffer<MemType, D> &)dst;
            auto &typed_src = (Buffer<const MemType, D> &)src;
            typed_dst.for_each_value([&](MemType &dst, MemType src) { dst = src; }, typed_src);
        } else if (T_is_void ? (type().bytes() == 2) : (sizeof(not_void_T) == 2)) {
            using MemType = uint16_t;
            auto &typed_dst = (Buffer<MemType, D> &)dst;
            auto &typed_src = (Buffer<const MemType, D> &)src;
            typed_dst.for_each_value([&](MemType &dst, MemType src) { dst = src; }, typed_src);
        } else if (T_is_void ? (type().bytes() == 4) : (sizeof(not_void_T) == 4)) {
            using MemType = uint32_t;
            auto &typed_dst = (Buffer<MemType, D> &)dst;
            auto &typed_src = (Buffer<const MemType, D> &)src;
            typed_dst.for_each_value([&](MemType &dst, MemType src) { dst = src; }, typed_src);
        } else if (T_is_void ? (type().bytes() == 8) : (sizeof(not_void_T) == 8)) {
            using MemType = uint64_t;
            auto &typed_dst = (Buffer<MemType, D> &)dst;
            auto &typed_src = (Buffer<const MemType, D> &)src;
            typed_dst.for_each_value([&](MemType &dst, MemType src) { dst = src; }, typed_src);
        } else {
            assert(false && "type().bytes() must be 1, 2, 4, or 8");
        }
        set_host_dirty();
    }
9984
9985 /** Make an image that refers to a sub-range of this image along
9986 * the given dimension. Asserts that the crop region is within
9987 * the existing bounds: you cannot "crop outwards", even if you know there
9988 * is valid Buffer storage (e.g. because you already cropped inwards). */
9989 Buffer<T, D> cropped(int d, int min, int extent) const {
9990 // Make a fresh copy of the underlying buffer (but not a fresh
9991 // copy of the allocation, if there is one).
9992 Buffer<T, D> im = *this;
9993
9994 // This guarantees the prexisting device ref is dropped if the
9995 // device_crop call fails and maintains the buffer in a consistent
9996 // state.
9997 im.device_deallocate();
9998
9999 im.crop_host(d, min, extent);
10000 if (buf.device_interface != nullptr) {
10001 complete_device_crop(im);
10002 }
10003 return im;
10004 }
10005
10006 /** Crop an image in-place along the given dimension. This does
10007 * not move any data around in memory - it just changes the min
10008 * and extent of the given dimension. */
10009 void crop(int d, int min, int extent) {
10010 // An optimization for non-device buffers. For the device case,
10011 // a temp buffer is required, so reuse the not-in-place version.
10012 // TODO(zalman|abadams): Are nop crops common enough to special
10013 // case the device part of the if to do nothing?
10014 if (buf.device_interface != nullptr) {
10015 *this = cropped(d, min, extent);
10016 } else {
10017 crop_host(d, min, extent);
10018 }
10019 }
10020
10021 /** Make an image that refers to a sub-rectangle of this image along
10022 * the first N dimensions. Asserts that the crop region is within
10023 * the existing bounds. The cropped image may drop any device handle
10024 * if the device_interface cannot accomplish the crop in-place. */
10025 Buffer<T, D> cropped(const std::vector<std::pair<int, int>> &rect) const {
10026 // Make a fresh copy of the underlying buffer (but not a fresh
10027 // copy of the allocation, if there is one).
10028 Buffer<T, D> im = *this;
10029
10030 // This guarantees the prexisting device ref is dropped if the
10031 // device_crop call fails and maintains the buffer in a consistent
10032 // state.
10033 im.device_deallocate();
10034
10035 im.crop_host(rect);
10036 if (buf.device_interface != nullptr) {
10037 complete_device_crop(im);
10038 }
10039 return im;
10040 }
10041
10042 /** Crop an image in-place along the first N dimensions. This does
10043 * not move any data around in memory, nor does it free memory. It
10044 * just rewrites the min/extent of each dimension to refer to a
10045 * subregion of the same allocation. */
10046 void crop(const std::vector<std::pair<int, int>> &rect) {
10047 // An optimization for non-device buffers. For the device case,
10048 // a temp buffer is required, so reuse the not-in-place version.
10049 // TODO(zalman|abadams): Are nop crops common enough to special
10050 // case the device part of the if to do nothing?
10051 if (buf.device_interface != nullptr) {
10052 *this = cropped(rect);
10053 } else {
10054 crop_host(rect);
10055 }
10056 }
10057
10058 /** Make an image which refers to the same data with using
10059 * translated coordinates in the given dimension. Positive values
10060 * move the image data to the right or down relative to the
10061 * coordinate system. Drops any device handle. */
10062 Buffer<T, D> translated(int d, int dx) const {
10063 Buffer<T, D> im = *this;
10064 im.translate(d, dx);
10065 return im;
10066 }
10067
10068 /** Translate an image in-place along one dimension by changing
10069 * how it is indexed. Does not move any data around in memory. */
10070 void translate(int d, int delta) {
10071 assert(d >= 0 && d < this->dimensions());
10072 device_deallocate();
10073 buf.dim[d].min += delta;
10074 }
10075
10076 /** Make an image which refers to the same data translated along
10077 * the first N dimensions. */
10078 Buffer<T, D> translated(const std::vector<int> &delta) const {
10079 Buffer<T, D> im = *this;
10080 im.translate(delta);
10081 return im;
10082 }
10083
10084 /** Translate an image along the first N dimensions by changing
10085 * how it is indexed. Does not move any data around in memory. */
10086 void translate(const std::vector<int> &delta) {
10087 device_deallocate();
10088 assert(delta.size() <= static_cast<decltype(delta.size())>(std::numeric_limits<int>::max()));
10089 int limit = (int)delta.size();
10090 assert(limit <= dimensions());
10091 for (int i = 0; i < limit; i++) {
10092 translate(i, delta[i]);
10093 }
10094 }
10095
10096 /** Set the min coordinate of an image in the first N dimensions. */
10097 // @{
10098 void set_min(const std::vector<int> &mins) {
10099 assert(mins.size() <= static_cast<decltype(mins.size())>(dimensions()));
10100 device_deallocate();
10101 for (size_t i = 0; i < mins.size(); i++) {
10102 buf.dim[i].min = mins[i];
10103 }
10104 }
10105
10106 template<typename... Args>
10107 void set_min(Args... args) {
10108 set_min(std::vector<int>{args...});
10109 }
10110 // @}
10111
10112 /** Test if a given coordinate is within the bounds of an image. */
10113 // @{
10114 bool contains(const std::vector<int> &coords) const {
10115 assert(coords.size() <= static_cast<decltype(coords.size())>(dimensions()));
10116 for (size_t i = 0; i < coords.size(); i++) {
10117 if (coords[i] < dim((int)i).min() || coords[i] > dim((int)i).max()) {
10118 return false;
10119 }
10120 }
10121 return true;
10122 }
10123
10124 template<typename... Args>
10125 bool contains(Args... args) const {
10126 return contains(std::vector<int>{args...});
10127 }
10128 // @}
10129
10130 /** Make a buffer which refers to the same data in the same layout
10131 * using a swapped indexing order for the dimensions given. So
10132 * A = B.transposed(0, 1) means that A(i, j) == B(j, i), and more
10133 * strongly that A.address_of(i, j) == B.address_of(j, i). */
10134 Buffer<T, D> transposed(int d1, int d2) const {
10135 Buffer<T, D> im = *this;
10136 im.transpose(d1, d2);
10137 return im;
10138 }
10139
10140 /** Transpose a buffer in-place by changing how it is indexed. For
10141 * example, transpose(0, 1) on a two-dimensional buffer means that
10142 * the value referred to by coordinates (i, j) is now reached at
10143 * the coordinates (j, i), and vice versa. This is done by
10144 * reordering the per-dimension metadata rather than by moving
10145 * data around in memory, so other views of the same memory will
10146 * not see the data as having been transposed. */
10147 void transpose(int d1, int d2) {
10148 assert(d1 >= 0 && d1 < this->dimensions());
10149 assert(d2 >= 0 && d2 < this->dimensions());
10150 std::swap(buf.dim[d1], buf.dim[d2]);
10151 }
10152
    /** A generalized transpose: instead of swapping two dimensions,
     * pass a vector that lists each dimension index exactly once, in
     * the desired order. This does not move any data around in memory
     * - it just permutes how it is indexed. */
    void transpose(const std::vector<int> &order) {
        assert((int)order.size() == dimensions());
        if (dimensions() < 2) {
            // My, that was easy
            return;
        }

        // Insertion-sort a copy of the permutation; every swap the sort
        // performs is mirrored as a pairwise transpose of this buffer,
        // so when order_sorted reaches identity the buffer's dims have
        // been permuted into the requested order.
        std::vector<int> order_sorted = order;
        for (size_t i = 1; i < order_sorted.size(); i++) {
            for (size_t j = i; j > 0 && order_sorted[j - 1] > order_sorted[j]; j--) {
                std::swap(order_sorted[j], order_sorted[j - 1]);
                transpose(j, j - 1);
            }
        }
    }
10172
10173 /** Make a buffer which refers to the same data in the same
10174 * layout using a different ordering of the dimensions. */
10175 Buffer<T, D> transposed(const std::vector<int> &order) const {
10176 Buffer<T, D> im = *this;
10177 im.transpose(order);
10178 return im;
10179 }
10180
10181 /** Make a lower-dimensional buffer that refers to one slice of
10182 * this buffer. */
10183 Buffer<T, D> sliced(int d, int pos) const {
10184 Buffer<T, D> im = *this;
10185
10186 // This guarantees the prexisting device ref is dropped if the
10187 // device_slice call fails and maintains the buffer in a consistent
10188 // state.
10189 im.device_deallocate();
10190
10191 im.slice_host(d, pos);
10192 if (buf.device_interface != nullptr) {
10193 complete_device_slice(im, d, pos);
10194 }
10195 return im;
10196 }
10197
    /** Make a lower-dimensional buffer that refers to one slice of this
     * buffer at the dimension's minimum. Equivalent to
     * sliced(d, dim(d).min()). */
    inline Buffer<T, D> sliced(int d) const {
        return sliced(d, dim(d).min());
    }
10203
10204 /** Rewrite the buffer to refer to a single lower-dimensional
10205 * slice of itself along the given dimension at the given
10206 * coordinate. Does not move any data around or free the original
10207 * memory, so other views of the same data are unaffected. */
10208 void slice(int d, int pos) {
10209 // An optimization for non-device buffers. For the device case,
10210 // a temp buffer is required, so reuse the not-in-place version.
10211 // TODO(zalman|abadams): Are nop slices common enough to special
10212 // case the device part of the if to do nothing?
10213 if (buf.device_interface != nullptr) {
10214 *this = sliced(d, pos);
10215 } else {
10216 slice_host(d, pos);
10217 }
10218 }
10219
    /** Slice a buffer in-place at the dimension's minimum. Equivalent
     * to slice(d, dim(d).min()). */
    inline void slice(int d) {
        slice(d, dim(d).min());
    }
10224
10225 /** Make a new buffer that views this buffer as a single slice in a
10226 * higher-dimensional space. The new dimension has extent one and
10227 * the given min. This operation is the opposite of slice. As an
10228 * example, the following condition is true:
10229 *
10230 \code
10231 im2 = im.embedded(1, 17);
10232 &im(x, y, c) == &im2(x, 17, y, c);
10233 \endcode
10234 */
10235 Buffer<T, D> embedded(int d, int pos = 0) const {
10236 Buffer<T, D> im(*this);
10237 im.embed(d, pos);
10238 return im;
10239 }
10240
10241 /** Embed a buffer in-place, increasing the
10242 * dimensionality. */
10243 void embed(int d, int pos = 0) {
10244 assert(d >= 0 && d <= dimensions());
10245 add_dimension();
10246 translate(dimensions() - 1, pos);
10247 for (int i = dimensions() - 1; i > d; i--) {
10248 transpose(i, i - 1);
10249 }
10250 }
10251
    /** Add a new dimension with a min of zero and an extent of
     * one. The stride is the extent of the outermost dimension times
     * its stride. The new dimension is the last dimension. This is a
     * special case of embed. */
    void add_dimension() {
        const int dims = buf.dimensions;
        buf.dimensions++;
        // The shape metadata lives either in the in-class `shape` array
        // (capacity D) or on the heap; growing may force a transition.
        if (buf.dim != shape) {
            // We're already on the heap. Reallocate.
            halide_dimension_t *new_shape = new halide_dimension_t[buf.dimensions];
            for (int i = 0; i < dims; i++) {
                new_shape[i] = buf.dim[i];
            }
            delete[] buf.dim;
            buf.dim = new_shape;
        } else if (dims == D) {
            // Transition from the in-class storage to the heap
            make_shape_storage(buf.dimensions);
            for (int i = 0; i < dims; i++) {
                buf.dim[i] = shape[i];
            }
        } else {
            // We still fit in the class
        }
        // New dimension: min 0, extent 1, and a dense stride computed
        // from the previously-outermost dimension (1 for a 0-d buffer).
        buf.dim[dims] = {0, 1, 0};
        if (dims == 0) {
            buf.dim[dims].stride = 1;
        } else {
            buf.dim[dims].stride = buf.dim[dims - 1].extent * buf.dim[dims - 1].stride;
        }
    }
10283
    /** Add a new dimension with a min of zero, an extent of one, and
     * the specified stride. The new dimension is the last
     * dimension. This is a special case of embed. */
    void add_dimension_with_stride(int s) {
        add_dimension();
        // Overwrite the dense stride add_dimension computed with the
        // caller-specified one.
        buf.dim[buf.dimensions - 1].stride = s;
    }
10291
    /** Methods for managing any GPU allocation. */
    // @{
    // Set the host dirty flag. Called by every operator()
    // access. Must be inlined so it can be hoisted out of loops.
    HALIDE_ALWAYS_INLINE
    void set_host_dirty(bool v = true) {
        // Setting host dirty while the device copy is dirty would lose
        // the device-side writes; require a copy_to_host first.
        assert((!v || !device_dirty()) && "Cannot set host dirty when device is already dirty. Call copy_to_host() before accessing the buffer from host.");
        buf.set_host_dirty(v);
    }
10301
    // Check if the device allocation is dirty. Called by
    // set_host_dirty, which is called by every accessor. Must be
    // inlined so it can be hoisted out of loops. Forwards to the
    // underlying halide_buffer_t flag.
    HALIDE_ALWAYS_INLINE
    bool device_dirty() const {
        return buf.device_dirty();
    }
10309
    // Check if the host allocation is dirty (has writes not yet copied
    // to the device). Forwards to the underlying halide_buffer_t flag.
    bool host_dirty() const {
        return buf.host_dirty();
    }
10313
    // Set the device dirty flag. Mirror of set_host_dirty: marking the
    // device dirty while host writes are pending would lose them.
    void set_device_dirty(bool v = true) {
        assert((!v || !host_dirty()) && "Cannot set device dirty when host is already dirty.");
        buf.set_device_dirty(v);
    }
10318
10319 int copy_to_host(void *ctx = nullptr) {
10320 if (device_dirty()) {
10321 return buf.device_interface->copy_to_host(ctx, &buf);
10322 }
10323 return 0;
10324 }
10325
10326 int copy_to_device(const struct halide_device_interface_t *device_interface, void *ctx = nullptr) {
10327 if (host_dirty()) {
10328 return device_interface->copy_to_device(ctx, &buf, device_interface);
10329 }
10330 return 0;
10331 }
10332
    // Allocate a device-side allocation for this buffer via the given
    // device interface. Returns the runtime's error code.
    int device_malloc(const struct halide_device_interface_t *device_interface, void *ctx = nullptr) {
        return device_interface->device_malloc(ctx, &buf, device_interface);
    }
10336
    // Free the device allocation, if any. Only legal when this buffer
    // uniquely owns a device allocation made via device_malloc.
    // Returns the runtime's error code (0 on success or no-op).
    int device_free(void *ctx = nullptr) {
        if (dev_ref_count) {
            assert(dev_ref_count->ownership == BufferDeviceOwnership::Allocated &&
                   "Can't call device_free on an unmanaged or wrapped native device handle. "
                   "Free the source allocation or call device_detach_native instead.");
            // Multiple people may be holding onto this dev field
            assert(dev_ref_count->count == 1 &&
                   "Multiple Halide::Runtime::Buffer objects share this device "
                   "allocation. Freeing it would create dangling references. "
                   "Don't call device_free on Halide buffers that you have copied or "
                   "passed by value.");
        }
        int ret = 0;
        if (buf.device_interface) {
            ret = buf.device_interface->device_free(ctx, &buf);
        }
        // Drop the (now refcount-1) ownership record as well.
        if (dev_ref_count) {
            delete dev_ref_count;
            dev_ref_count = nullptr;
        }
        return ret;
    }
10359
    // Wrap an existing native device handle (e.g. a GL texture or CUDA
    // pointer) without taking ownership of it. The buffer records
    // WrappedNative ownership so device_free cannot be called on it;
    // use device_detach_native to release the wrapper.
    int device_wrap_native(const struct halide_device_interface_t *device_interface,
                           uint64_t handle, void *ctx = nullptr) {
        assert(device_interface);
        dev_ref_count = new DeviceRefCount;
        dev_ref_count->ownership = BufferDeviceOwnership::WrappedNative;
        return device_interface->wrap_native(ctx, &buf, handle, device_interface);
    }
10367
    // Release a native device handle previously installed with
    // device_wrap_native, without freeing the underlying allocation.
    // Only legal on a uniquely-referenced wrapped handle. Returns the
    // runtime's error code.
    int device_detach_native(void *ctx = nullptr) {
        assert(dev_ref_count &&
               dev_ref_count->ownership == BufferDeviceOwnership::WrappedNative &&
               "Only call device_detach_native on buffers wrapping a native "
               "device handle via device_wrap_native. This buffer was allocated "
               "using device_malloc, or is unmanaged. "
               "Call device_free or free the original allocation instead.");
        // Multiple people may be holding onto this dev field
        assert(dev_ref_count->count == 1 &&
               "Multiple Halide::Runtime::Buffer objects share this device "
               "allocation. Freeing it could create dangling references. "
               "Don't call device_detach_native on Halide buffers that you "
               "have copied or passed by value.");
        int ret = 0;
        if (buf.device_interface) {
            ret = buf.device_interface->detach_native(ctx, &buf);
        }
        // The ownership record is unconditionally dropped: the asserts
        // above guarantee it exists and is uniquely referenced.
        delete dev_ref_count;
        dev_ref_count = nullptr;
        return ret;
    }
10389
    // Allocate host and device memory together via the given device
    // interface (some backends can share or pin such allocations).
    // Returns the runtime's error code.
    int device_and_host_malloc(const struct halide_device_interface_t *device_interface, void *ctx = nullptr) {
        return device_interface->device_and_host_malloc(ctx, &buf, device_interface);
    }
10393
    // Free a combined host+device allocation made with
    // device_and_host_malloc. Only legal when this buffer uniquely
    // owns such an allocation. Returns the runtime's error code.
    int device_and_host_free(const struct halide_device_interface_t *device_interface, void *ctx = nullptr) {
        if (dev_ref_count) {
            assert(dev_ref_count->ownership == BufferDeviceOwnership::AllocatedDeviceAndHost &&
                   "Can't call device_and_host_free on a device handle not allocated with device_and_host_malloc. "
                   "Free the source allocation or call device_detach_native instead.");
            // Multiple people may be holding onto this dev field
            assert(dev_ref_count->count == 1 &&
                   "Multiple Halide::Runtime::Buffer objects share this device "
                   "allocation. Freeing it would create dangling references. "
                   "Don't call device_and_host_free on Halide buffers that you have copied or "
                   "passed by value.");
        }
        int ret = 0;
        if (buf.device_interface) {
            ret = buf.device_interface->device_and_host_free(ctx, &buf);
        }
        // Drop the (now refcount-1) ownership record as well.
        if (dev_ref_count) {
            delete dev_ref_count;
            dev_ref_count = nullptr;
        }
        return ret;
    }
10416
    // Block until any outstanding device work on this buffer has
    // completed. Forwards to the underlying halide_buffer_t.
    int device_sync(void *ctx = nullptr) {
        return buf.device_sync(ctx);
    }
10420
    // True if this buffer currently carries a device-side handle
    // (a nonzero `device` field).
    bool has_device_allocation() const {
        return buf.device != 0;
    }
10424
10425 /** Return the method by which the device field is managed. */
10426 BufferDeviceOwnership device_ownership() const {
10427 if (dev_ref_count == nullptr) {
10428 return BufferDeviceOwnership::Allocated;
10429 }
10430 return dev_ref_count->ownership;
10431 }
10432 // @}
10433
10434 /** If you use the (x, y, c) indexing convention, then Halide
10435 * Buffers are stored planar by default. This function constructs
10436 * an interleaved RGB or RGBA image that can still be indexed
10437 * using (x, y, c). Passing it to a generator requires that the
10438 * generator has been compiled with support for interleaved (also
10439 * known as packed or chunky) memory layouts. */
10440 static Buffer<void, D> make_interleaved(halide_type_t t, int width, int height, int channels) {
10441 Buffer<void, D> im(t, channels, width, height);
10442 // Note that this is equivalent to calling transpose({2, 0, 1}),
10443 // but slightly more efficient.
10444 im.transpose(0, 1);
10445 im.transpose(1, 2);
10446 return im;
10447 }
10448
    /** If you use the (x, y, c) indexing convention, then Halide
     * Buffers are stored planar by default. This function constructs
     * an interleaved RGB or RGBA image that can still be indexed
     * using (x, y, c). Passing it to a generator requires that the
     * generator has been compiled with support for interleaved (also
     * known as packed or chunky) memory layouts. Typed convenience
     * wrapper around the halide_type_t overload. */
    static Buffer<T, D> make_interleaved(int width, int height, int channels) {
        return make_interleaved(static_halide_type(), width, height, channels);
    }
10458
10459 /** Wrap an existing interleaved image. */
10460 static Buffer<add_const_if_T_is_const<void>, D>
10461 make_interleaved(halide_type_t t, T *data, int width, int height, int channels) {
10462 Buffer<add_const_if_T_is_const<void>, D> im(t, data, channels, width, height);
10463 im.transpose(0, 1);
10464 im.transpose(1, 2);
10465 return im;
10466 }
10467
    /** Wrap an existing interleaved image. Typed convenience wrapper
     * around the halide_type_t overload; the data is not copied. */
    static Buffer<T, D> make_interleaved(T *data, int width, int height, int channels) {
        return make_interleaved(static_halide_type(), data, width, height, channels);
    }
10472
10473 /** Make a zero-dimensional Buffer */
10474 static Buffer<add_const_if_T_is_const<void>, D> make_scalar(halide_type_t t) {
10475 Buffer<add_const_if_T_is_const<void>, 1> buf(t, 1);
10476 buf.slice(0, 0);
10477 return buf;
10478 }
10479
10480 /** Make a zero-dimensional Buffer */
10481 static Buffer<T, D> make_scalar() {
10482 Buffer<T, 1> buf(1);
10483 buf.slice(0, 0);
10484 return buf;
10485 }
10486
10487 /** Make a zero-dimensional Buffer that points to non-owned, existing data */
10488 static Buffer<T, D> make_scalar(T *data) {
10489 Buffer<T, 1> buf(data, 1);
10490 buf.slice(0, 0);
10491 return buf;
10492 }
10493
10494 /** Make a buffer with the same shape and memory nesting order as
10495 * another buffer. It may have a different type. */
10496 template<typename T2, int D2>
10497 static Buffer<T, D> make_with_shape_of(Buffer<T2, D2> src,
10498 void *(*allocate_fn)(size_t) = nullptr,
10499 void (*deallocate_fn)(void *) = nullptr) {
10500
10501 const halide_type_t dst_type = T_is_void ? src.type() : halide_type_of<typename std::remove_cv<not_void_T>::type>();
10502 return Buffer<>::make_with_shape_of_helper(dst_type, src.dimensions(), src.buf.dim,
10503 allocate_fn, deallocate_fn);
10504 }
10505
10506private:
    // Non-templated worker for make_with_shape_of: allocates a new
    // densely-strided buffer of the given type whose dimensions have
    // the same extents and the same memory nesting order as `shape`.
    // `shape` is mutated freely (the caller passed src by value).
    static Buffer<> make_with_shape_of_helper(halide_type_t dst_type,
                                              int dimensions,
                                              halide_dimension_t *shape,
                                              void *(*allocate_fn)(size_t),
                                              void (*deallocate_fn)(void *)) {
        // Reorder the dimensions of src to have strides in increasing order,
        // recording each swap so it can be undone later.
        std::vector<int> swaps;
        for (int i = dimensions - 1; i > 0; i--) {
            for (int j = i; j > 0; j--) {
                if (shape[j - 1].stride > shape[j].stride) {
                    std::swap(shape[j - 1], shape[j]);
                    swaps.push_back(j);
                }
            }
        }

        // Rewrite the strides to be dense (this messes up src, which
        // is why we took it by value).
        for (int i = 0; i < dimensions; i++) {
            if (i == 0) {
                shape[i].stride = 1;
            } else {
                shape[i].stride = shape[i - 1].extent * shape[i - 1].stride;
            }
        }

        // Undo the dimension reordering by replaying the swaps in
        // reverse, so the dims are back in their original positions but
        // now carry dense strides in the original nesting order.
        while (!swaps.empty()) {
            int j = swaps.back();
            std::swap(shape[j - 1], shape[j]);
            swaps.pop_back();
        }

        // Use an explicit runtime type, and make dst a Buffer<void>, to allow
        // using this method with Buffer<void> for either src or dst.
        Buffer<> dst(dst_type, nullptr, dimensions, shape);
        dst.allocate(allocate_fn, deallocate_fn);

        return dst;
    }
10547
10548 template<typename... Args>
10549 HALIDE_ALWAYS_INLINE
10550 ptrdiff_t
10551 offset_of(int d, int first, Args... rest) const {
10552 return offset_of(d + 1, rest...) + (ptrdiff_t)this->buf.dim[d].stride * (first - this->buf.dim[d].min);
10553 }
10554
    // Base case of the variadic recursion: no coordinates left means
    // no further offset. Trailing dimensions are implicitly at min.
    HALIDE_ALWAYS_INLINE
    ptrdiff_t offset_of(int d) const {
        return 0;
    }
10559
    // Compute the address of the element at the given coordinates.
    // For Buffer<void>, offsets are in elements so the byte size must
    // be applied; for typed buffers storage_T is the element type and
    // pointer arithmetic already scales correctly.
    template<typename... Args>
    HALIDE_ALWAYS_INLINE
    storage_T *
    address_of(Args... args) const {
        if (T_is_void) {
            return (storage_T *)(this->buf.host) + offset_of(0, args...) * type().bytes();
        } else {
            return (storage_T *)(this->buf.host) + offset_of(0, args...);
        }
    }
10570
10571 HALIDE_ALWAYS_INLINE
10572 ptrdiff_t offset_of(const int *pos) const {
10573 ptrdiff_t offset = 0;
10574 for (int i = this->dimensions() - 1; i >= 0; i--) {
10575 offset += (ptrdiff_t)this->buf.dim[i].stride * (pos[i] - this->buf.dim[i].min);
10576 }
10577 return offset;
10578 }
10579
10580 HALIDE_ALWAYS_INLINE
10581 storage_T *address_of(const int *pos) const {
10582 if (T_is_void) {
10583 return (storage_T *)this->buf.host + offset_of(pos) * type().bytes();
10584 } else {
10585 return (storage_T *)this->buf.host + offset_of(pos);
10586 }
10587 }
10588
10589public:
    /** Get a pointer to the address of the min coordinate. */
    T *data() const {
        return (T *)(this->buf.host);
    }
10594
    /** Access elements. Use im(...) to get a reference to an element,
     * and use &im(...) to get the address of an element. If you pass
     * fewer arguments than the buffer has dimensions, the rest are
     * treated as their min coordinate. The non-const versions set the
     * host_dirty flag to true.
     */
    //@{
    template<typename... Args,
             typename = typename std::enable_if<AllInts<Args...>::value>::type>
    HALIDE_ALWAYS_INLINE const not_void_T &operator()(int first, Args... rest) const {
        static_assert(!T_is_void,
                      "Cannot use operator() on Buffer<void> types");
        // Reading stale host data while the device copy is dirty would
        // return garbage; require a copy_to_host first.
        assert(!device_dirty());
        return *((const not_void_T *)(address_of(first, rest...)));
    }
10610
    // Zero-argument const access: reads the element at the min
    // coordinate of every dimension (the host pointer's base).
    HALIDE_ALWAYS_INLINE
    const not_void_T &
    operator()() const {
        static_assert(!T_is_void,
                      "Cannot use operator() on Buffer<void> types");
        assert(!device_dirty());
        return *((const not_void_T *)(data()));
    }
10619
    // Const access with the coordinate supplied as an array of ints,
    // one per dimension.
    HALIDE_ALWAYS_INLINE
    const not_void_T &
    operator()(const int *pos) const {
        static_assert(!T_is_void,
                      "Cannot use operator() on Buffer<void> types");
        assert(!device_dirty());
        return *((const not_void_T *)(address_of(pos)));
    }
10628
    // Mutable variadic access. Marks the host copy dirty, since the
    // returned reference may be written through.
    template<typename... Args,
             typename = typename std::enable_if<AllInts<Args...>::value>::type>
    HALIDE_ALWAYS_INLINE
    not_void_T &
    operator()(int first, Args... rest) {
        static_assert(!T_is_void,
                      "Cannot use operator() on Buffer<void> types");
        set_host_dirty();
        return *((not_void_T *)(address_of(first, rest...)));
    }
10639
    // Zero-argument mutable access: the element at the min coordinate
    // of every dimension. Marks the host copy dirty.
    HALIDE_ALWAYS_INLINE
    not_void_T &
    operator()() {
        static_assert(!T_is_void,
                      "Cannot use operator() on Buffer<void> types");
        set_host_dirty();
        return *((not_void_T *)(data()));
    }
10648
    // Mutable access with the coordinate supplied as an array of ints,
    // one per dimension. Marks the host copy dirty.
    HALIDE_ALWAYS_INLINE
    not_void_T &
    operator()(const int *pos) {
        static_assert(!T_is_void,
                      "Cannot use operator() on Buffer<void> types");
        set_host_dirty();
        return *((not_void_T *)(address_of(pos)));
    }
    // @}
10658
10659 /** Tests that all values in this buffer are equal to val. */
10660 bool all_equal(not_void_T val) const {
10661 bool all_equal = true;
10662 for_each_element([&](const int *pos) { all_equal &= (*this)(pos) == val; });
10663 return all_equal;
10664 }
10665
10666 Buffer<T, D> &fill(not_void_T val) {
10667 set_host_dirty();
10668 for_each_value([=](T &v) { v = val; });
10669 return *this;
10670 }
10671
10672private:
    /** Helper functions for for_each_value. */
    // @{
    // Per-dimension traversal state for an N-buffer elementwise walk:
    // one shared extent plus one stride per participating buffer.
    template<int N>
    struct for_each_value_task_dim {
        std::ptrdiff_t extent;
        std::ptrdiff_t stride[N];
    };
10680
    // Given an array of strides, and a bunch of pointers to pointers
    // (all of different types), advance the pointers using the
    // strides. Peels one pointer per recursion step, consuming one
    // stride each time.
    template<typename Ptr, typename... Ptrs>
    HALIDE_ALWAYS_INLINE static void advance_ptrs(const std::ptrdiff_t *stride, Ptr &ptr, Ptrs &...ptrs) {
        ptr += *stride;
        advance_ptrs(stride + 1, ptrs...);
    }
10689
    // Base case of the recursion above: no pointers left to advance.
    HALIDE_ALWAYS_INLINE
    static void advance_ptrs(const std::ptrdiff_t *) {
    }
10693
    // Recursive traversal core for for_each_value: walks dimension d
    // of every buffer in lockstep, calling f once per element tuple.
    // The d == 0 case is split so that unit-stride innermost loops
    // become simple pointer increments the compiler can vectorize.
    template<typename Fn, typename Ptr, typename... Ptrs>
    HALIDE_NEVER_INLINE static void for_each_value_helper(Fn &&f, int d, bool innermost_strides_are_one,
                                                          const for_each_value_task_dim<sizeof...(Ptrs) + 1> *t, Ptr ptr, Ptrs... ptrs) {
        if (d == 0) {
            if (innermost_strides_are_one) {
                // All innermost strides are 1: advance by post-increment.
                Ptr end = ptr + t[0].extent;
                while (ptr != end) {
                    f(*ptr++, (*ptrs++)...);
                }
            } else {
                // General case: advance each pointer by its own stride.
                for (std::ptrdiff_t i = t[0].extent; i != 0; i--) {
                    f(*ptr, (*ptrs)...);
                    advance_ptrs(t[0].stride, ptr, ptrs...);
                }
            }
        } else {
            // Outer dimensions: recurse, then step all pointers by this
            // dimension's strides.
            for (std::ptrdiff_t i = t[d].extent; i != 0; i--) {
                for_each_value_helper(f, d - 1, innermost_strides_are_one, t, ptr, ptrs...);
                advance_ptrs(t[d].stride, ptr, ptrs...);
            }
        }
    }
10716
    // Non-templated (per-N) preparation for for_each_value: validates
    // the buffers, fills in the task-dim table t, reorders and fuses
    // dimensions to maximize the innermost loop, and returns whether
    // every buffer's innermost stride is 1 (enabling the fast path in
    // for_each_value_helper).
    template<int N>
    HALIDE_NEVER_INLINE static bool for_each_value_prep(for_each_value_task_dim<N> *t,
                                                        const halide_buffer_t **buffers) {
        // Check the buffers all have clean host allocations
        for (int i = 0; i < N; i++) {
            if (buffers[i]->device) {
                assert(buffers[i]->host &&
                       "Buffer passed to for_each_value has device allocation but no host allocation. Call allocate() and copy_to_host() first");
                assert(!buffers[i]->device_dirty() &&
                       "Buffer passed to for_each_value is dirty on device. Call copy_to_host() first");
            } else {
                assert(buffers[i]->host &&
                       "Buffer passed to for_each_value has no host or device allocation");
            }
        }

        const int dimensions = buffers[0]->dimensions;

        // Extract the strides in all the dimensions. All buffers must
        // agree on dimensionality and on each dimension's min/extent.
        for (int i = 0; i < dimensions; i++) {
            for (int j = 0; j < N; j++) {
                assert(buffers[j]->dimensions == dimensions);
                assert(buffers[j]->dim[i].extent == buffers[0]->dim[i].extent &&
                       buffers[j]->dim[i].min == buffers[0]->dim[i].min);
                const int s = buffers[j]->dim[i].stride;
                t[i].stride[j] = s;
            }
            t[i].extent = buffers[0]->dim[i].extent;

            // Order the dimensions by stride, so that the traversal is cache-coherent.
            // Use the last dimension for this, because this is the source in copies.
            // It appears to be better to optimize read order than write order.
            // (Insertion sort: each new entry bubbles down into place.)
            for (int j = i; j > 0 && t[j].stride[N - 1] < t[j - 1].stride[N - 1]; j--) {
                std::swap(t[j], t[j - 1]);
            }
        }

        // flatten dimensions where possible to make a larger inner
        // loop for autovectorization. Two adjacent dims fuse when, in
        // every buffer, the outer stride equals inner stride * extent.
        int d = dimensions;
        for (int i = 1; i < d; i++) {
            bool flat = true;
            for (int j = 0; j < N; j++) {
                flat = flat && t[i - 1].stride[j] * t[i - 1].extent == t[i].stride[j];
            }
            if (flat) {
                t[i - 1].extent *= t[i].extent;
                for (int j = i; j < d; j++) {
                    t[j] = t[j + 1];
                }
                i--;
                d--;
                t[d].extent = 1;
            }
        }

        // The fast inner loop is only valid if every buffer has stride
        // 1 in the (post-sort) innermost dimension.
        bool innermost_strides_are_one = true;
        if (dimensions > 0) {
            for (int i = 0; i < N; i++) {
                innermost_strides_are_one &= (t[0].stride[i] == 1);
            }
        }

        return innermost_strides_are_one;
    }
10782
    // Shared implementation behind both const and non-const
    // for_each_value overloads. Builds the traversal table on the
    // stack (alloca) and dispatches to the recursive helper; a 0-d
    // buffer is just a single call on the base elements.
    template<typename Fn, typename... Args, int N = sizeof...(Args) + 1>
    void for_each_value_impl(Fn &&f, Args &&...other_buffers) const {
        if (dimensions() > 0) {
            Buffer<>::for_each_value_task_dim<N> *t =
                (Buffer<>::for_each_value_task_dim<N> *)HALIDE_ALLOCA((dimensions() + 1) * sizeof(for_each_value_task_dim<N>));
            // Move the preparatory code into a non-templated helper to
            // save code size.
            const halide_buffer_t *buffers[] = {&buf, (&other_buffers.buf)...};
            bool innermost_strides_are_one = Buffer<>::for_each_value_prep(t, buffers);

            Buffer<>::for_each_value_helper(f, dimensions() - 1,
                                            innermost_strides_are_one,
                                            t,
                                            data(), (other_buffers.data())...);
        } else {
            f(*data(), (*other_buffers.data())...);
        }
    }
    // @}
10801 // @}
10802
10803public:
    /** Call a function on every value in the buffer, and the
     * corresponding values in some number of other buffers of the
     * same size. The function should take a reference, const
     * reference, or value of the correct type for each buffer. This
     * effectively lifts a function of scalars to an element-wise
     * function of buffers. This produces code that the compiler can
     * autovectorize. This is slightly cheaper than for_each_element,
     * because it does not need to track the coordinates.
     *
     * Note that constness of Buffers is preserved: a const Buffer<T> (for either
     * 'this' or the other-buffers arguments) will allow mutation of the
     * buffer contents, while a Buffer<const T> will not. Attempting to specify
     * a mutable reference for the lambda argument of a Buffer<const T>
     * will result in a compilation error. */
    // @{
    template<typename Fn, typename... Args, int N = sizeof...(Args) + 1>
    HALIDE_ALWAYS_INLINE const Buffer<T, D> &for_each_value(Fn &&f, Args &&...other_buffers) const {
        for_each_value_impl(f, std::forward<Args>(other_buffers)...);
        return *this;
    }
10824
    // Non-const overload: identical traversal, but returns a mutable
    // reference to *this for chaining.
    template<typename Fn, typename... Args, int N = sizeof...(Args) + 1>
    HALIDE_ALWAYS_INLINE
    Buffer<T, D> &
    for_each_value(Fn &&f, Args &&...other_buffers) {
        for_each_value_impl(f, std::forward<Args>(other_buffers)...);
        return *this;
    }
    // @}
10833
10834private:
10835 // Helper functions for for_each_element
    // Inclusive iteration bounds for a single dimension, used by the
    // for_each_element machinery below.
    struct for_each_element_task_dim {
        int min, max;
    };
10839
10840 /** If f is callable with this many args, call it. The first
10841 * argument is just to make the overloads distinct. Actual
10842 * overload selection is done using the enable_if. */
    template<typename Fn,
             typename... Args,
             typename = decltype(std::declval<Fn>()(std::declval<Args>()...))>
    HALIDE_ALWAYS_INLINE static void for_each_element_variadic(int, int, const for_each_element_task_dim *, Fn &&f, Args... args) {
        // Base case: f accepts exactly the coordinates gathered so far.
        f(args...);
    }
10849
10850 /** If the above overload is impossible, we add an outer loop over
10851 * an additional argument and try again. */
    template<typename Fn,
             typename... Args>
    HALIDE_ALWAYS_INLINE static void for_each_element_variadic(double, int d, const for_each_element_task_dim *t, Fn &&f, Args... args) {
        // Loop over dimension d, prepending each coordinate to the argument
        // list and recursing. The double (vs int) first parameter makes this
        // overload a worse match, so it is only chosen when f cannot yet be
        // invoked with args... (i.e. the overload above SFINAEs out).
        for (int i = t[d].min; i <= t[d].max; i++) {
            for_each_element_variadic(0, d - 1, t, std::forward<Fn>(f), i, args...);
        }
    }
10859
10860 /** Determine the minimum number of arguments a callable can take
10861 * using the same trick. */
    template<typename Fn,
             typename... Args,
             typename = decltype(std::declval<Fn>()(std::declval<Args>()...))>
    HALIDE_ALWAYS_INLINE static int num_args(int, Fn &&, Args...) {
        // f is callable with this many int arguments; report the count.
        return (int)(sizeof...(Args));
    }
10868
10869 /** The recursive version is only enabled up to a recursion limit
10870 * of 256. This catches callables that aren't callable with any
10871 * number of ints. */
    template<typename Fn,
             typename... Args>
    HALIDE_ALWAYS_INLINE static int num_args(double, Fn &&f, Args... args) {
        // f is not callable with args... yet: append one more int argument
        // and retry. The static_assert bounds the recursion for callables
        // that can never be called with any number of ints.
        static_assert(sizeof...(args) <= 256,
                      "Callable passed to for_each_element must accept either a const int *,"
                      " or up to 256 ints. No such operator found. Expect infinite template recursion.");
        return num_args(0, std::forward<Fn>(f), 0, args...);
    }
10880
10881 /** A version where the callable takes a position array instead,
10882 * with compile-time recursion on the dimensionality. This
10883 * overload is preferred to the one below using the same int vs
10884 * double trick as above, but is impossible once d hits -1 using
10885 * std::enable_if. */
    template<int d,
             typename Fn,
             typename = typename std::enable_if<(d >= 0)>::type>
    HALIDE_ALWAYS_INLINE static void for_each_element_array_helper(int, const for_each_element_task_dim *t, Fn &&f, int *pos) {
        // Fill in coordinate d, recursing at compile time to the
        // next-inner dimension for each value.
        for (pos[d] = t[d].min; pos[d] <= t[d].max; pos[d]++) {
            for_each_element_array_helper<d - 1>(0, t, std::forward<Fn>(f), pos);
        }
    }
10894
10895 /** Base case for recursion above. */
    template<int d,
             typename Fn,
             typename = typename std::enable_if<(d < 0)>::type>
    HALIDE_ALWAYS_INLINE static void for_each_element_array_helper(double, const for_each_element_task_dim *t, Fn &&f, int *pos) {
        // All coordinates have been filled in; pass the completed
        // position array to the callable.
        f(pos);
    }
10902
10903 /** A run-time-recursive version (instead of
10904 * compile-time-recursive) that requires the callable to take a
10905 * pointer to a position array instead. Dispatches to the
10906 * compile-time-recursive version once the dimensionality gets
10907 * small. */
10908 template<typename Fn>
10909 static void for_each_element_array(int d, const for_each_element_task_dim *t, Fn &&f, int *pos) {
10910 if (d == -1) {
10911 f(pos);
10912 } else if (d == 0) {
10913 // Once the dimensionality gets small enough, dispatch to
10914 // a compile-time-recursive version for better codegen of
10915 // the inner loops.
10916 for_each_element_array_helper<0, Fn>(0, t, std::forward<Fn>(f), pos);
10917 } else if (d == 1) {
10918 for_each_element_array_helper<1, Fn>(0, t, std::forward<Fn>(f), pos);
10919 } else if (d == 2) {
10920 for_each_element_array_helper<2, Fn>(0, t, std::forward<Fn>(f), pos);
10921 } else if (d == 3) {
10922 for_each_element_array_helper<3, Fn>(0, t, std::forward<Fn>(f), pos);
10923 } else {
10924 for (pos[d] = t[d].min; pos[d] <= t[d].max; pos[d]++) {
10925 for_each_element_array(d - 1, t, std::forward<Fn>(f), pos);
10926 }
10927 }
10928 }
10929
10930 /** We now have two overloads for for_each_element. This one
10931 * triggers if the callable takes a const int *.
10932 */
10933 template<typename Fn,
10934 typename = decltype(std::declval<Fn>()((const int *)nullptr))>
10935 static void for_each_element(int, int dims, const for_each_element_task_dim *t, Fn &&f, int check = 0) {
10936 int *pos = (int *)HALIDE_ALLOCA(dims * sizeof(int));
10937 for_each_element_array(dims - 1, t, std::forward<Fn>(f), pos);
10938 }
10939
10940 /** This one triggers otherwise. It treats the callable as
10941 * something that takes some number of ints. */
    template<typename Fn>
    HALIDE_ALWAYS_INLINE static void for_each_element(double, int dims, const for_each_element_task_dim *t, Fn &&f) {
        // Determine how many int coordinates f accepts, then iterate over
        // just that many innermost dimensions.
        int args = num_args(0, std::forward<Fn>(f));
        assert(dims >= args);
        for_each_element_variadic(0, args - 1, t, std::forward<Fn>(f));
    }
10948
10949 template<typename Fn>
10950 void for_each_element_impl(Fn &&f) const {
10951 for_each_element_task_dim *t =
10952 (for_each_element_task_dim *)HALIDE_ALLOCA(dimensions() * sizeof(for_each_element_task_dim));
10953 for (int i = 0; i < dimensions(); i++) {
10954 t[i].min = dim(i).min();
10955 t[i].max = dim(i).max();
10956 }
10957 for_each_element(0, dimensions(), t, std::forward<Fn>(f));
10958 }
10959
10960public:
10961 /** Call a function at each site in a buffer. This is likely to be
10962 * much slower than using Halide code to populate a buffer, but is
10963 * convenient for tests. If the function has more arguments than the
10964 * buffer has dimensions, the remaining arguments will be zero. If it
10965 * has fewer arguments than the buffer has dimensions then the last
10966 * few dimensions of the buffer are not iterated over. For example,
10967 * the following code exploits this to set a floating point RGB image
10968 * to red:
10969
10970 \code
10971 Buffer<float, 3> im(100, 100, 3);
10972 im.for_each_element([&](int x, int y) {
10973 im(x, y, 0) = 1.0f;
10974 im(x, y, 1) = 0.0f;
         im(x, y, 2) = 0.0f;
10976 });
10977 \endcode
10978
 * The compiled code is equivalent to writing a nested for loop,
10980 * and compilers are capable of optimizing it in the same way.
10981 *
10982 * If the callable can be called with an int * as the sole argument,
10983 * that version is called instead. Each location in the buffer is
10984 * passed to it in a coordinate array. This version is higher-overhead
10985 * than the variadic version, but is useful for writing generic code
10986 * that accepts buffers of arbitrary dimensionality. For example, the
10987 * following sets the value at all sites in an arbitrary-dimensional
10988 * buffer to their first coordinate:
10989
10990 \code
10991 im.for_each_element([&](const int *pos) {im(pos) = pos[0];});
10992 \endcode
10993
10994 * It is also possible to use for_each_element to iterate over entire
10995 * rows or columns by cropping the buffer to a single column or row
10996 * respectively and iterating over elements of the result. For example,
10997 * to set the diagonal of the image to 1 by iterating over the columns:
10998
10999 \code
11000 Buffer<float, 3> im(100, 100, 3);
11001 im.sliced(1, 0).for_each_element([&](int x, int c) {
11002 im(x, x, c) = 1.0f;
11003 });
11004 \endcode
11005
11006 * Or, assuming the memory layout is known to be dense per row, one can
11007 * memset each row of an image like so:
11008
11009 \code
11010 Buffer<float, 3> im(100, 100, 3);
11011 im.sliced(0, 0).for_each_element([&](int y, int c) {
11012 memset(&im(0, y, c), 0, sizeof(float) * im.width());
11013 });
11014 \endcode
11015
11016 */
11017 // @{
    // Const overload; returns *this to allow chaining.
    template<typename Fn>
    HALIDE_ALWAYS_INLINE const Buffer<T, D> &for_each_element(Fn &&f) const {
        for_each_element_impl(f);
        return *this;
    }
11023
    // Non-const overload; returns *this to allow chaining.
    template<typename Fn>
    HALIDE_ALWAYS_INLINE
        Buffer<T, D> &
        for_each_element(Fn &&f) {
        for_each_element_impl(f);
        return *this;
    }
11031 // @}
11032
11033private:
    // Adapts a value-producing callable for use with for_each_element:
    // evaluates f at each visited coordinate and stores the result into
    // the corresponding site of *buf.
    template<typename Fn>
    struct FillHelper {
        Fn f;                // the user's value-producing callable
        Buffer<T, D> *buf;   // the buffer being filled (non-owning)

        // Only participates when f is callable with this coordinate list.
        template<typename... Args,
                 typename = decltype(std::declval<Fn>()(std::declval<Args>()...))>
        void operator()(Args... args) {
            (*buf)(args...) = f(args...);
        }

        FillHelper(Fn &&f, Buffer<T, D> *buf)
            : f(std::forward<Fn>(f)), buf(buf) {
        }
    };
11049
11050public:
11051 /** Fill a buffer by evaluating a callable at every site. The
11052 * callable should look much like a callable passed to
11053 * for_each_element, but it should return the value that should be
11054 * stored to the coordinate corresponding to the arguments. */
    // Disabled for arithmetic Fn so that the value-fill overload is chosen
    // for plain scalar arguments.
    template<typename Fn,
             typename = typename std::enable_if<!std::is_arithmetic<typename std::decay<Fn>::type>::value>::type>
    Buffer<T, D> &fill(Fn &&f) {
        // We'll go via for_each_element. We need a variadic wrapper lambda.
        // (FillHelper stores f's result at each visited coordinate.)
        FillHelper<Fn> wrapper(std::forward<Fn>(f), this);
        return for_each_element(wrapper);
    }
11062
    /** Check if an input buffer passed to an extern stage is a bounds
     * query. Compared to doing the host pointer check directly,
     * this both adds clarity to code and will facilitate moving to
     * another representation for bounds query arguments. */
    bool is_bounds_query() const {
        // Delegate to the underlying halide_buffer_t representation.
        return buf.is_bounds_query();
    }
11070
11071 /** Convenient check to verify that all of the interesting bytes in the Buffer
11072 * are initialized under MSAN. Note that by default, we use for_each_value() here so that
11073 * we skip any unused padding that isn't part of the Buffer; this isn't efficient,
11074 * but in MSAN mode, it doesn't matter. (Pass true for the flag to force check
11075 * the entire Buffer storage.) */
11076 void msan_check_mem_is_initialized(bool entire = false) const {
11077#if defined(__has_feature)
11078#if __has_feature(memory_sanitizer)
11079 if (entire) {
11080 __msan_check_mem_is_initialized(data(), size_in_bytes());
11081 } else {
11082 for_each_value([](T &v) { __msan_check_mem_is_initialized(&v, sizeof(T)); ; });
11083 }
11084#endif
11085#endif
11086 }
11087};
11088
11089} // namespace Runtime
11090} // namespace Halide
11091
11092#undef HALIDE_ALLOCA
11093
11094#endif // HALIDE_RUNTIME_IMAGE_H
11095
11096namespace Halide {
11097
11098template<typename T = void>
11099class Buffer;
11100
11101namespace Internal {
11102
struct BufferContents {
    // Ref-counted so Halide::Buffer handles can be cheaply copied and
    // shared; mutable so a const Buffer can still adjust the count.
    mutable RefCount ref_count;
    // Optional name for this buffer (see Buffer::set_name / Buffer::name).
    std::string name;
    // The type-erased runtime buffer being wrapped.
    Runtime::Buffer<> buf;
};
11108
11109Expr buffer_accessor(const Buffer<> &buf, const std::vector<Expr> &args);
11110
// Trait: true iff Args... is a sequence of values convertible to int,
// optionally followed by a single final argument convertible to
// std::string (the buffer name). Primary template: false by default.
template<typename... Args>
struct all_ints_and_optional_name : std::false_type {};

// Peel one leading argument: it must be convertible to int, and the
// remainder must again satisfy the trait.
template<typename First, typename... Rest>
struct all_ints_and_optional_name<First, Rest...> : meta_and<std::is_convertible<First, int>,
                                                             all_ints_and_optional_name<Rest...>> {};

// A single trailing argument may be either the name or one more int.
template<typename T>
struct all_ints_and_optional_name<T> : meta_or<std::is_convertible<T, std::string>,
                                               std::is_convertible<T, int>> {};

// An empty argument list is trivially valid.
template<>
struct all_ints_and_optional_name<> : std::true_type {};
11124
// Walk to the end of a parameter pack and return the optional trailing
// buffer name: a final argument convertible to std::string is the name;
// anything else means no name was supplied.

// Single remaining argument that is not string-like: no name.
template<typename T,
         typename = typename std::enable_if<!std::is_convertible<T, std::string>::value>::type>
std::string get_name_from_end_of_parameter_pack(T &&) {
    return "";
}

// Single remaining argument that is the name.
inline std::string get_name_from_end_of_parameter_pack(const std::string &name) {
    return name;
}

// Empty pack: no name.
inline std::string get_name_from_end_of_parameter_pack() {
    return "";
}

// Two or more arguments remain: discard the head and keep scanning.
template<typename First,
         typename Second,
         typename... Args>
std::string get_name_from_end_of_parameter_pack(First head, Second next, Args &&...tail) {
    return get_name_from_end_of_parameter_pack(next, std::forward<Args>(tail)...);
}
11145
// Terminator: the remaining argument is the optional buffer name, so the
// shape is complete.
inline void get_shape_from_start_of_parameter_pack_helper(std::vector<int> &, const std::string &) {
}

// Terminator: no arguments remain.
inline void get_shape_from_start_of_parameter_pack_helper(std::vector<int> &) {
}

// Accumulate one leading integer extent, then continue with the rest.
template<typename... Args>
void get_shape_from_start_of_parameter_pack_helper(std::vector<int> &shape, int extent, Args &&...rest) {
    shape.push_back(extent);
    get_shape_from_start_of_parameter_pack_helper(shape, std::forward<Args>(rest)...);
}

// Collect the run of leading integer arguments (the shape), ignoring an
// optional trailing name string.
template<typename... Args>
std::vector<int> get_shape_from_start_of_parameter_pack(Args &&...args) {
    std::vector<int> shape;
    get_shape_from_start_of_parameter_pack_helper(shape, std::forward<Args>(args)...);
    return shape;
}
11164
// Propagate constness from T onto T2: yields 'const T2' iff T is const.
template<typename T, typename T2>
using add_const_if_T_is_const = typename std::conditional<std::is_const<T>::value, const T2, T2>::type;
11167
11168// Helpers to produce the name of a Buffer element type (a Halide
// scalar type, or void, possibly with const). Useful for error
// messages.
// Emit the C name of the scalar type corresponding to T.
template<typename T>
void buffer_type_name_non_const(std::ostream &s) {
    s << type_to_c_type(type_of<T>(), false);
}

// Special case: void is not a Halide scalar type, so print it directly.
template<>
inline void buffer_type_name_non_const<void>(std::ostream &s) {
    s << "void";
}
11180
11181template<typename T>
11182std::string buffer_type_name() {
11183 std::ostringstream oss;
11184 if (std::is_const<T>::value) {
11185 oss << "const ";
11186 }
11187 buffer_type_name_non_const<typename std::remove_const<T>::type>(oss);
11188 return oss.str();
11189}
11190
11191} // namespace Internal
11192
11193/** A Halide::Buffer is a named shared reference to a
11194 * Halide::Runtime::Buffer.
11195 *
11196 * A Buffer<T1> can refer to a Buffer<T2> if T1 is const whenever T2
11197 * is const, and either T1 = T2 or T1 is void. A Buffer<void> can
11198 * refer to any Buffer of any non-const type, and the default
11199 * template parameter is T = void.
11200 */
11201template<typename T>
11202class Buffer {
11203 Internal::IntrusivePtr<Internal::BufferContents> contents;
11204
11205 template<typename T2>
11206 friend class Buffer;
11207
11208 template<typename T2>
11209 static void assert_can_convert_from(const Buffer<T2> &other) {
11210 if (!other.defined()) {
11211 // Avoid UB of deferencing offset of a null contents ptr
11212 static_assert((!std::is_const<T2>::value || std::is_const<T>::value),
11213 "Can't convert from a Buffer<const T> to a Buffer<T>");
11214 static_assert(std::is_same<typename std::remove_const<T>::type,
11215 typename std::remove_const<T2>::type>::value ||
11216 std::is_void<T>::value ||
11217 std::is_void<T2>::value,
11218 "type mismatch constructing Buffer");
11219 } else {
11220 // Don't delegate to
11221 // Runtime::Buffer<T>::assert_can_convert_from. It might
11222 // not assert is NDEBUG is defined. user_assert is
11223 // friendlier anyway because it reports line numbers when
11224 // debugging symbols are found, it throws an exception
11225 // when exceptions are enabled, and we can print the
11226 // actual types in question.
11227 user_assert(Runtime::Buffer<T>::can_convert_from(*(other.get())))
11228 << "Type mismatch constructing Buffer. Can't construct Buffer<"
11229 << Internal::buffer_type_name<T>() << "> from Buffer<"
11230 << type_to_c_type(other.type(), false) << ">\n";
11231 }
11232 }
11233
11234public:
11235 typedef T ElemType;
11236
11237 // This class isn't final (and is subclassed from the Python binding
11238 // code, at least) so it needs a virtual dtor.
11239 virtual ~Buffer() = default;
11240
11241 /** Make a null Buffer, which points to no Runtime::Buffer */
11242 Buffer() = default;
11243
11244 /** Trivial copy constructor. */
11245 Buffer(const Buffer &that) = default;
11246
11247 /** Trivial copy assignment operator. */
11248 Buffer &operator=(const Buffer &that) = default;
11249
11250 /** Trivial move assignment operator. */
11251 Buffer &operator=(Buffer &&) noexcept = default;
11252
11253 /** Make a Buffer from a Buffer of a different type */
11254 template<typename T2>
11255 Buffer(const Buffer<T2> &other)
11256 : contents(other.contents) {
11257 assert_can_convert_from(other);
11258 }
11259
11260 /** Move construct from a Buffer of a different type */
11261 template<typename T2>
11262 Buffer(Buffer<T2> &&other) noexcept {
11263 assert_can_convert_from(other);
11264 contents = std::move(other.contents);
11265 }
11266
11267 /** Construct a Buffer that captures and owns an rvalue Runtime::Buffer */
11268 template<int D>
11269 Buffer(Runtime::Buffer<T, D> &&buf, const std::string &name = "")
11270 : contents(new Internal::BufferContents) {
11271 contents->buf = std::move(buf);
11272 if (name.empty()) {
11273 contents->name = Internal::make_entity_name(this, "Halide:.*:Buffer<.*>", 'b');
11274 } else {
11275 contents->name = name;
11276 }
11277 }
11278
11279 /** Constructors that match Runtime::Buffer with two differences:
11280 * 1) They take a Type instead of a halide_type_t
11281 * 2) There is an optional last string argument that gives the buffer a specific name
11282 */
11283 // @{
11284 template<typename... Args,
11285 typename = typename std::enable_if<Internal::all_ints_and_optional_name<Args...>::value>::type>
11286 explicit Buffer(Type t,
11287 int first, Args... rest)
11288 : Buffer(Runtime::Buffer<T>(t, Internal::get_shape_from_start_of_parameter_pack(first, rest...)),
11289 Internal::get_name_from_end_of_parameter_pack(rest...)) {
11290 }
11291
11292 explicit Buffer(const halide_buffer_t &buf,
11293 const std::string &name = "")
11294 : Buffer(Runtime::Buffer<T>(buf), name) {
11295 }
11296
11297 template<typename... Args,
11298 typename = typename std::enable_if<Internal::all_ints_and_optional_name<Args...>::value>::type>
11299 explicit Buffer(int first, Args... rest)
11300 : Buffer(Runtime::Buffer<T>(Internal::get_shape_from_start_of_parameter_pack(first, rest...)),
11301 Internal::get_name_from_end_of_parameter_pack(rest...)) {
11302 }
11303
11304 explicit Buffer(Type t,
11305 const std::vector<int> &sizes,
11306 const std::string &name = "")
11307 : Buffer(Runtime::Buffer<T>(t, sizes), name) {
11308 }
11309
11310 explicit Buffer(Type t,
11311 const std::vector<int> &sizes,
11312 const std::vector<int> &storage_order,
11313 const std::string &name = "")
11314 : Buffer(Runtime::Buffer<T>(t, sizes, storage_order), name) {
11315 }
11316
11317 explicit Buffer(const std::vector<int> &sizes,
11318 const std::string &name = "")
11319 : Buffer(Runtime::Buffer<T>(sizes), name) {
11320 }
11321
11322 explicit Buffer(const std::vector<int> &sizes,
11323 const std::vector<int> &storage_order,
11324 const std::string &name = "")
11325 : Buffer(Runtime::Buffer<T>(sizes, storage_order), name) {
11326 }
11327
11328 template<typename Array, size_t N>
11329 explicit Buffer(Array (&vals)[N],
11330 const std::string &name = "")
11331 : Buffer(Runtime::Buffer<T>(vals), name) {
11332 }
11333
11334 template<typename... Args,
11335 typename = typename std::enable_if<Internal::all_ints_and_optional_name<Args...>::value>::type>
11336 explicit Buffer(Type t,
11337 Internal::add_const_if_T_is_const<T, void> *data,
11338 int first, Args &&...rest)
11339 : Buffer(Runtime::Buffer<T>(t, data, Internal::get_shape_from_start_of_parameter_pack(first, rest...)),
11340 Internal::get_name_from_end_of_parameter_pack(rest...)) {
11341 }
11342
11343 template<typename... Args,
11344 typename = typename std::enable_if<Internal::all_ints_and_optional_name<Args...>::value>::type>
11345 explicit Buffer(Type t,
11346 Internal::add_const_if_T_is_const<T, void> *data,
11347 const std::vector<int> &sizes,
11348 const std::string &name = "")
11349 : Buffer(Runtime::Buffer<T>(t, data, sizes, name)) {
11350 }
11351
11352 template<typename... Args,
11353 typename = typename std::enable_if<Internal::all_ints_and_optional_name<Args...>::value>::type>
11354 explicit Buffer(T *data,
11355 int first, Args &&...rest)
11356 : Buffer(Runtime::Buffer<T>(data, Internal::get_shape_from_start_of_parameter_pack(first, rest...)),
11357 Internal::get_name_from_end_of_parameter_pack(rest...)) {
11358 }
11359
11360 explicit Buffer(T *data,
11361 const std::vector<int> &sizes,
11362 const std::string &name = "")
11363 : Buffer(Runtime::Buffer<T>(data, sizes), name) {
11364 }
11365
11366 explicit Buffer(Type t,
11367 Internal::add_const_if_T_is_const<T, void> *data,
11368 const std::vector<int> &sizes,
11369 const std::string &name = "")
11370 : Buffer(Runtime::Buffer<T>(t, data, sizes), name) {
11371 }
11372
11373 explicit Buffer(Type t,
11374 Internal::add_const_if_T_is_const<T, void> *data,
11375 int d,
11376 const halide_dimension_t *shape,
11377 const std::string &name = "")
11378 : Buffer(Runtime::Buffer<T>(t, data, d, shape), name) {
11379 }
11380
11381 explicit Buffer(T *data,
11382 int d,
11383 const halide_dimension_t *shape,
11384 const std::string &name = "")
11385 : Buffer(Runtime::Buffer<T>(data, d, shape), name) {
11386 }
11387
11388 static Buffer<T> make_scalar(const std::string &name = "") {
11389 return Buffer<T>(Runtime::Buffer<T>::make_scalar(), name);
11390 }
11391
11392 static Buffer<> make_scalar(Type t, const std::string &name = "") {
11393 return Buffer<>(Runtime::Buffer<>::make_scalar(t), name);
11394 }
11395
11396 static Buffer<T> make_scalar(T *data, const std::string &name = "") {
11397 return Buffer<T>(Runtime::Buffer<T>::make_scalar(data), name);
11398 }
11399
11400 static Buffer<T> make_interleaved(int width, int height, int channels, const std::string &name = "") {
11401 return Buffer<T>(Runtime::Buffer<T>::make_interleaved(width, height, channels),
11402 name);
11403 }
11404
11405 static Buffer<> make_interleaved(Type t, int width, int height, int channels, const std::string &name = "") {
11406 return Buffer<>(Runtime::Buffer<>::make_interleaved(t, width, height, channels),
11407 name);
11408 }
11409
11410 static Buffer<T> make_interleaved(T *data, int width, int height, int channels, const std::string &name = "") {
11411 return Buffer<T>(Runtime::Buffer<T>::make_interleaved(data, width, height, channels),
11412 name);
11413 }
11414
11415 static Buffer<Internal::add_const_if_T_is_const<T, void>>
11416 make_interleaved(Type t, T *data, int width, int height, int channels, const std::string &name = "") {
11417 using T2 = Internal::add_const_if_T_is_const<T, void>;
11418 return Buffer<T2>(Runtime::Buffer<T2>::make_interleaved(t, data, width, height, channels),
11419 name);
11420 }
11421
11422 template<typename T2>
11423 static Buffer<T> make_with_shape_of(Buffer<T2> src,
11424 void *(*allocate_fn)(size_t) = nullptr,
11425 void (*deallocate_fn)(void *) = nullptr,
11426 const std::string &name = "") {
11427 return Buffer<T>(Runtime::Buffer<T>::make_with_shape_of(*src.get(), allocate_fn, deallocate_fn),
11428 name);
11429 }
11430
11431 template<typename T2>
11432 static Buffer<T> make_with_shape_of(const Runtime::Buffer<T2> &src,
11433 void *(*allocate_fn)(size_t) = nullptr,
11434 void (*deallocate_fn)(void *) = nullptr,
11435 const std::string &name = "") {
11436 return Buffer<T>(Runtime::Buffer<T>::make_with_shape_of(src, allocate_fn, deallocate_fn),
11437 name);
11438 }
11439 // @}
11440
11441 /** Buffers are optionally named. */
11442 // @{
11443 void set_name(const std::string &n) {
11444 contents->name = n;
11445 }
11446
11447 const std::string &name() const {
11448 return contents->name;
11449 }
11450 // @}
11451
11452 /** Check if two Buffer objects point to the same underlying Buffer */
11453 template<typename T2>
11454 bool same_as(const Buffer<T2> &other) const {
11455 return (const void *)(contents.get()) == (const void *)(other.contents.get());
11456 }
11457
11458 /** Check if this Buffer refers to an existing
11459 * Buffer. Default-constructed Buffer objects do not refer to any
11460 * existing Buffer. */
11461 bool defined() const {
11462 return contents.defined();
11463 }
11464
11465 /** Get a pointer to the underlying Runtime::Buffer */
11466 // @{
11467 Runtime::Buffer<T> *get() {
11468 // It's already type-checked, so no need to use as<T>.
11469 return (Runtime::Buffer<T> *)(&contents->buf);
11470 }
11471 const Runtime::Buffer<T> *get() const {
11472 return (const Runtime::Buffer<T> *)(&contents->buf);
11473 }
11474 // @}
11475
11476 // We forward numerous methods from the underlying Buffer
11477#define HALIDE_BUFFER_FORWARD_CONST(method) \
11478 template<typename... Args> \
11479 auto method(Args &&...args) const->decltype(std::declval<const Runtime::Buffer<T>>().method(std::forward<Args>(args)...)) { \
11480 user_assert(defined()) << "Undefined buffer calling const method " #method "\n"; \
11481 return get()->method(std::forward<Args>(args)...); \
11482 }
11483
11484#define HALIDE_BUFFER_FORWARD(method) \
11485 template<typename... Args> \
11486 auto method(Args &&...args)->decltype(std::declval<Runtime::Buffer<T>>().method(std::forward<Args>(args)...)) { \
11487 user_assert(defined()) << "Undefined buffer calling method " #method "\n"; \
11488 return get()->method(std::forward<Args>(args)...); \
11489 }
11490
11491// This is a weird-looking but effective workaround for a deficiency in "perfect forwarding":
11492// namely, it can't really handle initializer-lists. The idea here is that we declare
11493// the expected type to be passed on, and that allows the compiler to handle it.
11494// The weirdness comes in with the variadic macro: the problem is that the type
11495// we want to forward might be something like `std::vector<std::pair<int, int>>`,
11496// which contains a comma, which throws a big wrench in C++ macro system.
11497// However... since all we really need to do is capture the remainder of the macro,
11498// and forward it as is, we can just use ... to allow an arbitrary number of commas,
11499// then use __VA_ARGS__ to forward the mess as-is, and while it looks horrible, it
11500// works.
11501#define HALIDE_BUFFER_FORWARD_INITIALIZER_LIST(method, ...) \
11502 inline auto method(const __VA_ARGS__ &a)->decltype(std::declval<Runtime::Buffer<T>>().method(a)) { \
11503 user_assert(defined()) << "Undefined buffer calling method " #method "\n"; \
11504 return get()->method(a); \
11505 }
11506
11507 /** Does the same thing as the equivalent Halide::Runtime::Buffer method */
11508 // @{
11509 HALIDE_BUFFER_FORWARD(raw_buffer)
11510 HALIDE_BUFFER_FORWARD_CONST(raw_buffer)
11511 HALIDE_BUFFER_FORWARD_CONST(dimensions)
11512 HALIDE_BUFFER_FORWARD_CONST(dim)
11513 HALIDE_BUFFER_FORWARD_CONST(width)
11514 HALIDE_BUFFER_FORWARD_CONST(height)
11515 HALIDE_BUFFER_FORWARD_CONST(channels)
11516 HALIDE_BUFFER_FORWARD_CONST(min)
11517 HALIDE_BUFFER_FORWARD_CONST(extent)
11518 HALIDE_BUFFER_FORWARD_CONST(stride)
11519 HALIDE_BUFFER_FORWARD_CONST(left)
11520 HALIDE_BUFFER_FORWARD_CONST(right)
11521 HALIDE_BUFFER_FORWARD_CONST(top)
11522 HALIDE_BUFFER_FORWARD_CONST(bottom)
11523 HALIDE_BUFFER_FORWARD_CONST(number_of_elements)
11524 HALIDE_BUFFER_FORWARD_CONST(size_in_bytes)
11525 HALIDE_BUFFER_FORWARD_CONST(begin)
11526 HALIDE_BUFFER_FORWARD_CONST(end)
11527 HALIDE_BUFFER_FORWARD(data)
11528 HALIDE_BUFFER_FORWARD_CONST(data)
11529 HALIDE_BUFFER_FORWARD_CONST(contains)
11530 HALIDE_BUFFER_FORWARD(crop)
11531 HALIDE_BUFFER_FORWARD_INITIALIZER_LIST(crop, std::vector<std::pair<int, int>>)
11532 HALIDE_BUFFER_FORWARD(slice)
11533 HALIDE_BUFFER_FORWARD_CONST(sliced)
11534 HALIDE_BUFFER_FORWARD(embed)
11535 HALIDE_BUFFER_FORWARD_CONST(embedded)
11536 HALIDE_BUFFER_FORWARD(set_min)
11537 HALIDE_BUFFER_FORWARD(translate)
11538 HALIDE_BUFFER_FORWARD_INITIALIZER_LIST(translate, std::vector<int>)
11539 HALIDE_BUFFER_FORWARD(transpose)
11540 HALIDE_BUFFER_FORWARD_CONST(transposed)
11541 HALIDE_BUFFER_FORWARD(add_dimension)
11542 HALIDE_BUFFER_FORWARD(copy_to_host)
11543 HALIDE_BUFFER_FORWARD(copy_to_device)
11544 HALIDE_BUFFER_FORWARD_CONST(has_device_allocation)
11545 HALIDE_BUFFER_FORWARD_CONST(host_dirty)
11546 HALIDE_BUFFER_FORWARD_CONST(device_dirty)
11547 HALIDE_BUFFER_FORWARD(set_host_dirty)
11548 HALIDE_BUFFER_FORWARD(set_device_dirty)
11549 HALIDE_BUFFER_FORWARD(device_sync)
11550 HALIDE_BUFFER_FORWARD(device_malloc)
11551 HALIDE_BUFFER_FORWARD(device_wrap_native)
11552 HALIDE_BUFFER_FORWARD(device_detach_native)
11553 HALIDE_BUFFER_FORWARD(allocate)
11554 HALIDE_BUFFER_FORWARD(deallocate)
11555 HALIDE_BUFFER_FORWARD(device_deallocate)
11556 HALIDE_BUFFER_FORWARD(device_free)
11557 HALIDE_BUFFER_FORWARD_CONST(all_equal)
11558
11559#undef HALIDE_BUFFER_FORWARD
11560#undef HALIDE_BUFFER_FORWARD_CONST
11561
11562 template<typename Fn, typename... Args>
11563 Buffer<T> &for_each_value(Fn &&f, Args... other_buffers) {
11564 get()->for_each_value(std::forward<Fn>(f), (*std::forward<Args>(other_buffers).get())...);
11565 return *this;
11566 }
11567
11568 template<typename Fn, typename... Args>
11569 const Buffer<T> &for_each_value(Fn &&f, Args... other_buffers) const {
11570 get()->for_each_value(std::forward<Fn>(f), (*std::forward<Args>(other_buffers).get())...);
11571 return *this;
11572 }
11573
11574 template<typename Fn>
11575 Buffer<T> &for_each_element(Fn &&f) {
11576 get()->for_each_element(std::forward<Fn>(f));
11577 return *this;
11578 }
11579
11580 template<typename Fn>
11581 const Buffer<T> &for_each_element(Fn &&f) const {
11582 get()->for_each_element(std::forward<Fn>(f));
11583 return *this;
11584 }
11585
11586 template<typename FnOrValue>
11587 Buffer<T> &fill(FnOrValue &&f) {
11588 get()->fill(std::forward<FnOrValue>(f));
11589 return *this;
11590 }
11591
11592 static constexpr bool has_static_halide_type = Runtime::Buffer<T>::has_static_halide_type;
11593
11594 static halide_type_t static_halide_type() {
11595 return Runtime::Buffer<T>::static_halide_type();
11596 }
11597
11598 template<typename T2>
11599 static bool can_convert_from(const Buffer<T2> &other) {
11600 return Halide::Runtime::Buffer<T>::can_convert_from(*other.get());
11601 }
11602
11603 // Note that since Runtime::Buffer stores halide_type_t rather than Halide::Type,
11604 // there is no handle-specific type information, so all handle types are
11605 // considered equivalent to void* here. (This only matters if you are making
11606 // a Buffer-of-handles, which is not really a real use case...)
11607 Type type() const {
11608 return contents->buf.type();
11609 }
11610
11611 template<typename T2>
11612 Buffer<T2> as() const {
11613 return Buffer<T2>(*this);
11614 }
11615
11616 Buffer<T> copy() const {
11617 return Buffer<T>(std::move(contents->buf.as<T>().copy()));
11618 }
11619
11620 template<typename T2>
11621 void copy_from(const Buffer<T2> &other) {
11622 contents->buf.copy_from(*other.get());
11623 }
11624
11625 template<typename... Args>
11626 auto operator()(int first, Args &&...args) -> decltype(std::declval<Runtime::Buffer<T>>()(first, std::forward<Args>(args)...)) {
11627 return (*get())(first, std::forward<Args>(args)...);
11628 }
11629
11630 template<typename... Args>
11631 auto operator()(int first, Args &&...args) const -> decltype(std::declval<const Runtime::Buffer<T>>()(first, std::forward<Args>(args)...)) {
11632 return (*get())(first, std::forward<Args>(args)...);
11633 }
11634
11635 auto operator()(const int *pos) -> decltype(std::declval<Runtime::Buffer<T>>()(pos)) {
11636 return (*get())(pos);
11637 }
11638
11639 auto operator()(const int *pos) const -> decltype(std::declval<const Runtime::Buffer<T>>()(pos)) {
11640 return (*get())(pos);
11641 }
11642
11643 auto operator()() -> decltype(std::declval<Runtime::Buffer<T>>()()) {
11644 return (*get())();
11645 }
11646
11647 auto operator()() const -> decltype(std::declval<const Runtime::Buffer<T>>()()) {
11648 return (*get())();
11649 }
11650 // @}
11651
11652 /** Make an Expr that loads from this concrete buffer at a computed coordinate. */
11653 // @{
    template<typename... Args>
    Expr operator()(const Expr &first, Args... rest) const {
        // Gather the coordinates into a vector and defer to the
        // vector-of-Exprs overload.
        std::vector<Expr> args = {first, rest...};
        return (*this)(args);
    }
11659
    // NOTE(review): the template parameter pack is unused here; presumably
    // kept for symmetry with the variadic overload above — confirm before
    // removing.
    template<typename... Args>
    Expr operator()(const std::vector<Expr> &args) const {
        return buffer_accessor(Buffer<>(*this), args);
    }
11664 // @}
11665
11666 /** Copy to the GPU, using the device API that is the default for the given Target. */
    int copy_to_device(const Target &t = get_jit_target_from_environment()) {
        // Resolve Default_GPU against the target's enabled device APIs.
        return copy_to_device(DeviceAPI::Default_GPU, t);
    }
11670
11671 /** Copy to the GPU, using the given device API */
    /** Returns the error code from the underlying runtime call (zero on
     * success). */
    int copy_to_device(const DeviceAPI &d, const Target &t = get_jit_target_from_environment()) {
        return contents->buf.copy_to_device(get_device_interface_for_device_api(d, t, "Buffer::copy_to_device"));
    }
11675
11676 /** Allocate on the GPU, using the device API that is the default for the given Target. */
    int device_malloc(const Target &t = get_jit_target_from_environment()) {
        // Resolve Default_GPU against the target's enabled device APIs.
        return device_malloc(DeviceAPI::Default_GPU, t);
    }
11680
11681 /** Allocate storage on the GPU, using the given device API */
    /** Returns the error code from the underlying runtime call (zero on
     * success). */
    int device_malloc(const DeviceAPI &d, const Target &t = get_jit_target_from_environment()) {
        return contents->buf.device_malloc(get_device_interface_for_device_api(d, t, "Buffer::device_malloc"));
    }
11685
11686 /** Wrap a native handle, using the given device API.
11687 * It is a bad idea to pass DeviceAPI::Default_GPU to this routine
11688 * as the handle argument must match the API that the default
11689 * resolves to and it is clearer and more reliable to pass the
11690 * resolved DeviceAPI explicitly. */
    int device_wrap_native(const DeviceAPI &d, uint64_t handle, const Target &t = get_jit_target_from_environment()) {
        // 'handle' is an opaque device-API-specific object handle.
        return contents->buf.device_wrap_native(get_device_interface_for_device_api(d, t, "Buffer::device_wrap_native"), handle);
    }
11694};
11695
11696} // namespace Halide
11697
11698#endif
11699#ifndef HALIDE_MODULUS_REMAINDER_H
11700#define HALIDE_MODULUS_REMAINDER_H
11701
11702/** \file
11703 * Routines for statically determining what expressions are divisible by.
11704 */
11705
11706#include <cstdint>
11707
11708namespace Halide {
11709
11710struct Expr;
11711
11712namespace Internal {
11713
11714template<typename T>
11715class Scope;
11716
11717/** The result of modulus_remainder analysis. These represent strided
11718 * subsets of the integers. A ModulusRemainder object m represents all
11719 * integers x such that there exists y such that x == m.modulus * y +
11720 * m.remainder. Note that under this definition a set containing a
11721 * single integer (a constant) is represented using a modulus of
11722 * zero. These sets can be combined with several mathematical
11723 * operators in the obvious way. E.g. m1 + m2 contains (at least) all
11724 * integers x1 + x2 such that x1 belongs to m1 and x2 belongs to
11725 * m2. These combinations are conservative. If some internal math
11726 * would overflow, it defaults to all of the integers (modulus == 1,
11727 * remainder == 0). */
11728
struct ModulusRemainder {
    ModulusRemainder() = default;

    /** Construct the set of all integers congruent to rem modulo mod. */
    ModulusRemainder(int64_t mod, int64_t rem)
        : modulus(mod), remainder(rem) {
    }

    // The defaults describe the set of all integers: x == 1 * y + 0.
    int64_t modulus = 1, remainder = 0;

    // Conservative union of two sets: the result contains every element
    // of both operands, and possibly some extra elements.
    static ModulusRemainder unify(const ModulusRemainder &a, const ModulusRemainder &b);

    // Conservatively-large intersection. Everything in the result is in
    // at least one of the two sets, but not always both.
    static ModulusRemainder intersect(const ModulusRemainder &a, const ModulusRemainder &b);

    /** Structural equality of the (modulus, remainder) pair. */
    bool operator==(const ModulusRemainder &other) const {
        return modulus == other.modulus &&
               remainder == other.remainder;
    }
};
11749
/** Pointwise arithmetic on strided sets. Each result conservatively
 * contains every value obtainable by applying the operator to one element
 * of each operand set (see the struct comment above for the overflow
 * convention). */
ModulusRemainder operator+(const ModulusRemainder &a, const ModulusRemainder &b);
ModulusRemainder operator-(const ModulusRemainder &a, const ModulusRemainder &b);
ModulusRemainder operator*(const ModulusRemainder &a, const ModulusRemainder &b);
ModulusRemainder operator/(const ModulusRemainder &a, const ModulusRemainder &b);
ModulusRemainder operator%(const ModulusRemainder &a, const ModulusRemainder &b);

/** Versions where the right operand is a single constant (a set with
 * modulus zero). */
ModulusRemainder operator+(const ModulusRemainder &a, int64_t b);
ModulusRemainder operator-(const ModulusRemainder &a, int64_t b);
ModulusRemainder operator*(const ModulusRemainder &a, int64_t b);
ModulusRemainder operator/(const ModulusRemainder &a, int64_t b);
ModulusRemainder operator%(const ModulusRemainder &a, int64_t b);
11761
11762/** For things like alignment analysis, often it's helpful to know
11763 * if an integer expression is some multiple of a constant plus
11764 * some other constant. For example, it is straight-forward to
11765 * deduce that ((10*x + 2)*(6*y - 3) - 1) is congruent to five
11766 * modulo six.
11767 *
11768 * We get the most information when the modulus is large. E.g. if
11769 * something is congruent to 208 modulo 384, then we also know it's
11770 * congruent to 0 mod 8, and we can possibly use it as an index for an
11771 * aligned load. If all else fails, we can just say that an integer is
11772 * congruent to zero modulo one.
11773 */
11774ModulusRemainder modulus_remainder(const Expr &e);
11775
11776/** If we have alignment information about external variables, we can
11777 * let the analysis know about that using this version of
11778 * modulus_remainder: */
11779ModulusRemainder modulus_remainder(const Expr &e, const Scope<ModulusRemainder> &scope);
11780
11781/** Reduce an expression modulo some integer. Returns true and assigns
11782 * to remainder if an answer could be found. */
11783///@{
11784bool reduce_expr_modulo(const Expr &e, int64_t modulus, int64_t *remainder);
11785bool reduce_expr_modulo(const Expr &e, int64_t modulus, int64_t *remainder, const Scope<ModulusRemainder> &scope);
11786///@}
11787
11788void modulus_remainder_test();
11789
11790/** The greatest common divisor of two integers */
11791int64_t gcd(int64_t, int64_t);
11792
11793/** The least common multiple of two integers */
11794int64_t lcm(int64_t, int64_t);
11795
11796} // namespace Internal
11797} // namespace Halide
11798
11799#endif
11800#ifndef HALIDE_REDUCTION_H
11801#define HALIDE_REDUCTION_H
11802
11803/** \file
11804 * Defines internal classes related to Reduction Domains
11805 */
11806
11807
11808namespace Halide {
11809namespace Internal {
11810
11811class IRMutator;
11812
11813/** A single named dimension of a reduction domain */
struct ReductionVariable {
    std::string var;   // Name of the dimension's variable
    Expr min, extent;  // Iteration bounds as a (min, extent) pair

    /** This lets you use a ReductionVariable as a key in a map of the form
     * map<ReductionVariable, Foo, ReductionVariable::Compare> */
    struct Compare {
        bool operator()(const ReductionVariable &a, const ReductionVariable &b) const {
            // Ordered by variable name only; bounds do not participate.
            return a.var < b.var;
        }
    };
};
11826
11827struct ReductionDomainContents;
11828
11829/** A reference-counted handle on a reduction domain, which is just a
11830 * vector of ReductionVariable. */
class ReductionDomain {
    // Ref-counted shared state (defined elsewhere).
    IntrusivePtr<ReductionDomainContents> contents;

public:
    /** This lets you use a ReductionDomain as a key in a map of the form
     * map<ReductionDomain, Foo, ReductionDomain::Compare> */
    struct Compare {
        bool operator()(const ReductionDomain &a, const ReductionDomain &b) const {
            // Ordered by identity (pointer comparison), not by contents.
            internal_assert(a.contents.defined() && b.contents.defined());
            return a.contents < b.contents;
        }
    };

    /** Construct a new nullptr reduction domain */
    ReductionDomain()
        : contents(nullptr) {
    }

    /** Construct a reduction domain that spans the outer product of
     * all values of the given ReductionVariable in scanline order,
     * with the start of the vector being innermost, and the end of
     * the vector being outermost. */
    ReductionDomain(const std::vector<ReductionVariable> &domain);

    /** Return a deep copy of this ReductionDomain. */
    ReductionDomain deep_copy() const;

    /** Is this handle non-nullptr */
    bool defined() const {
        return contents.defined();
    }

    /** Tests for equality of reference. Only one reduction domain is
     * allowed per reduction function, and this is used to verify
     * that */
    bool same_as(const ReductionDomain &other) const {
        return contents.same_as(other.contents);
    }

    /** Immutable access to the reduction variables. */
    const std::vector<ReductionVariable> &domain() const;

    /** Add predicate to the reduction domain. See \ref RDom::where
     * for more details. */
    void where(Expr predicate);

    /** Return the predicate defined on this reduction domain. */
    Expr predicate() const;

    /** Set the predicate, replacing any previously set predicate. */
    void set_predicate(const Expr &);

    /** Split predicate into vector of ANDs. If there is no predicate (i.e. all
     * iteration domain in this reduction domain is valid), this returns an
     * empty vector. */
    std::vector<Expr> split_predicate() const;

    /** Mark RDom as frozen, which means it cannot accept new predicates. An
     * RDom is frozen once it is used in a Func's update definition. */
    void freeze();

    /** Check if a RDom has been frozen. If so, it is an error to add new
     * predicates. */
    bool frozen() const;

    /** Pass an IRVisitor through to all Exprs referenced in the
     * ReductionDomain. */
    void accept(IRVisitor *) const;

    /** Pass an IRMutator through to all Exprs referenced in the
     * ReductionDomain. */
    void mutate(IRMutator *);
};
11904
11905void split_predicate_test();
11906
11907} // namespace Internal
11908} // namespace Halide
11909
11910#endif
11911
11912namespace Halide {
11913namespace Internal {
11914
11915class Function;
11916
11917/** The actual IR nodes begin here. Remember that all the Expr
11918 * nodes also have a public "type" property */
11919
11920/** Cast a node from one type to another. Can't change vector widths. */
struct Cast : public ExprNode<Cast> {
    Expr value;  // The expression being converted

    /** Construct an Expr that converts 'v' to type 't'. */
    static Expr make(Type t, Expr v);

    static const IRNodeType _node_type = IRNodeType::Cast;
};
11928
11929/** The sum of two expressions */
struct Add : public ExprNode<Add> {
    Expr a, b;

    /** Construct an Expr computing a + b. */
    static Expr make(Expr a, Expr b);

    static const IRNodeType _node_type = IRNodeType::Add;
};
11937
11938/** The difference of two expressions */
struct Sub : public ExprNode<Sub> {
    Expr a, b;

    /** Construct an Expr computing a - b. */
    static Expr make(Expr a, Expr b);

    static const IRNodeType _node_type = IRNodeType::Sub;
};
11946
11947/** The product of two expressions */
struct Mul : public ExprNode<Mul> {
    Expr a, b;

    /** Construct an Expr computing a * b. */
    static Expr make(Expr a, Expr b);

    static const IRNodeType _node_type = IRNodeType::Mul;
};
11955
11956/** The ratio of two expressions */
struct Div : public ExprNode<Div> {
    Expr a, b;

    /** Construct an Expr computing a / b. */
    static Expr make(Expr a, Expr b);

    static const IRNodeType _node_type = IRNodeType::Div;
};
11964
11965/** The remainder of a / b. Mostly equivalent to '%' in C, except that
11966 * the result here is always positive. For floats, this is equivalent
11967 * to calling fmod. */
struct Mod : public ExprNode<Mod> {
    Expr a, b;

    /** Construct an Expr computing the always-non-negative remainder of
     * a / b (see the struct comment above). */
    static Expr make(Expr a, Expr b);

    static const IRNodeType _node_type = IRNodeType::Mod;
};
11975
11976/** The lesser of two values. */
struct Min : public ExprNode<Min> {
    Expr a, b;

    /** Construct an Expr computing min(a, b). */
    static Expr make(Expr a, Expr b);

    static const IRNodeType _node_type = IRNodeType::Min;
};
11984
11985/** The greater of two values */
struct Max : public ExprNode<Max> {
    Expr a, b;

    /** Construct an Expr computing max(a, b). */
    static Expr make(Expr a, Expr b);

    static const IRNodeType _node_type = IRNodeType::Max;
};
11993
11994/** Is the first expression equal to the second */
struct EQ : public ExprNode<EQ> {
    Expr a, b;

    /** Construct an Expr computing a == b. */
    static Expr make(Expr a, Expr b);

    static const IRNodeType _node_type = IRNodeType::EQ;
};
12002
12003/** Is the first expression not equal to the second */
struct NE : public ExprNode<NE> {
    Expr a, b;

    /** Construct an Expr computing a != b. */
    static Expr make(Expr a, Expr b);

    static const IRNodeType _node_type = IRNodeType::NE;
};
12011
12012/** Is the first expression less than the second. */
struct LT : public ExprNode<LT> {
    Expr a, b;

    /** Construct an Expr computing a < b. */
    static Expr make(Expr a, Expr b);

    static const IRNodeType _node_type = IRNodeType::LT;
};
12020
12021/** Is the first expression less than or equal to the second. */
struct LE : public ExprNode<LE> {
    Expr a, b;

    /** Construct an Expr computing a <= b. */
    static Expr make(Expr a, Expr b);

    static const IRNodeType _node_type = IRNodeType::LE;
};
12029
12030/** Is the first expression greater than the second. */
struct GT : public ExprNode<GT> {
    Expr a, b;

    /** Construct an Expr computing a > b. */
    static Expr make(Expr a, Expr b);

    static const IRNodeType _node_type = IRNodeType::GT;
};
12038
12039/** Is the first expression greater than or equal to the second. */
struct GE : public ExprNode<GE> {
    Expr a, b;

    /** Construct an Expr computing a >= b. */
    static Expr make(Expr a, Expr b);

    static const IRNodeType _node_type = IRNodeType::GE;
};
12047
12048/** Logical and - are both expressions true */
struct And : public ExprNode<And> {
    Expr a, b;

    /** Construct an Expr computing a && b. */
    static Expr make(Expr a, Expr b);

    static const IRNodeType _node_type = IRNodeType::And;
};
12056
/** Logical or - is at least one of the expressions true */
struct Or : public ExprNode<Or> {
    Expr a, b;

    /** Construct an Expr computing a || b. */
    static Expr make(Expr a, Expr b);

    static const IRNodeType _node_type = IRNodeType::Or;
};
12065
/** Logical not - true if the expression is false */
struct Not : public ExprNode<Not> {
    Expr a;

    /** Construct an Expr computing !a. */
    static Expr make(Expr a);

    static const IRNodeType _node_type = IRNodeType::Not;
};
12074
/** A ternary operator. Evaluates 'true_value' and 'false_value',
12076 * then selects between them based on 'condition'. Equivalent to
12077 * the ternary operator in C. */
struct Select : public ExprNode<Select> {
    Expr condition, true_value, false_value;

    /** Construct an Expr yielding true_value where condition holds and
     * false_value elsewhere. Both values are evaluated (see the struct
     * comment above). */
    static Expr make(Expr condition, Expr true_value, Expr false_value);

    static const IRNodeType _node_type = IRNodeType::Select;
};
12085
12086/** Load a value from a named symbol if predicate is true. The buffer
12087 * is treated as an array of the 'type' of this Load node. That is,
12088 * the buffer has no inherent type. The name may be the name of an
12089 * enclosing allocation, an input or output buffer, or any other
12090 * symbol of type Handle(). */
struct Load : public ExprNode<Load> {
    std::string name;  // The symbol being loaded from

    // The load only happens where 'predicate' is true; 'index' is
    // measured in units of this node's 'type'.
    Expr predicate, index;

    // If it's a load from an image argument or compiled-in constant
    // image, this will point to that
    Buffer<> image;

    // If it's a load from an image parameter, this points to that
    Parameter param;

    // The alignment of the index. If the index is a vector, this is
    // the alignment of the first lane.
    ModulusRemainder alignment;

    static Expr make(Type type, const std::string &name,
                     Expr index, Buffer<> image,
                     Parameter param,
                     Expr predicate,
                     ModulusRemainder alignment);

    static const IRNodeType _node_type = IRNodeType::Load;
};
12115
12116/** A linear ramp vector node. This is vector with 'lanes' elements,
12117 * where element i is 'base' + i*'stride'. This is a convenient way to
12118 * pass around vectors without busting them up into individual
12119 * elements. E.g. a dense vector load from a buffer can use a ramp
12120 * node with stride 1 as the index. */
struct Ramp : public ExprNode<Ramp> {
    Expr base, stride;  // Element i of the vector is base + i * stride
    int lanes;          // Number of vector elements

    /** Construct a vector Expr with the given base, stride, and lane
     * count. */
    static Expr make(Expr base, Expr stride, int lanes);

    static const IRNodeType _node_type = IRNodeType::Ramp;
};
12129
12130/** A vector with 'lanes' elements, in which every element is
12131 * 'value'. This is a special case of the ramp node above, in which
12132 * the stride is zero. */
struct Broadcast : public ExprNode<Broadcast> {
    Expr value;  // The scalar replicated across all lanes
    int lanes;   // Number of vector elements

    /** Construct a vector Expr with every lane equal to 'value'. */
    static Expr make(Expr value, int lanes);

    static const IRNodeType _node_type = IRNodeType::Broadcast;
};
12141
12142/** A let expression, like you might find in a functional
12143 * language. Within the expression \ref Let::body, instances of the Var
12144 * node \ref Let::name refer to \ref Let::value. */
struct Let : public ExprNode<Let> {
    std::string name;
    Expr value, body;  // Within 'body', Variables named 'name' refer to 'value'

    static Expr make(const std::string &name, Expr value, Expr body);

    static const IRNodeType _node_type = IRNodeType::Let;
};
12153
12154/** The statement form of a let node. Within the statement 'body',
12155 * instances of the Var named 'name' refer to 'value' */
struct LetStmt : public StmtNode<LetStmt> {
    std::string name;
    Expr value;
    Stmt body;  // Within 'body', Variables named 'name' refer to 'value'

    static Stmt make(const std::string &name, Expr value, Stmt body);

    static const IRNodeType _node_type = IRNodeType::LetStmt;
};
12165
12166/** If the 'condition' is false, then evaluate and return the message,
12167 * which should be a call to an error function. */
struct AssertStmt : public StmtNode<AssertStmt> {
    // if condition then val else error out with message
    Expr condition;
    Expr message;  // Typically a call to an error-reporting function

    static Stmt make(Expr condition, Expr message);

    static const IRNodeType _node_type = IRNodeType::AssertStmt;
};
12177
/** This node is a helpful annotation to do with permissions. If 'is_producer' is
12179 * set to true, this represents a producer node which may also contain updates;
12180 * otherwise, this represents a consumer node. If the producer node contains
12181 * updates, the body of the node will be a block of 'produce' and 'update'
12182 * in that order. In a producer node, the access is read-write only (or write
12183 * only if it doesn't have updates). In a consumer node, the access is read-only.
12184 * None of this is actually enforced, the node is purely for informative purposes
12185 * to help out our analysis during lowering. For every unique ProducerConsumer,
12186 * there is an associated Realize node with the same name that creates the buffer
12187 * being read from or written to in the body of the ProducerConsumer.
12188 */
struct ProducerConsumer : public StmtNode<ProducerConsumer> {
    std::string name;  // Matches the associated Realize node's name
    bool is_producer;
    Stmt body;

    static Stmt make(const std::string &name, bool is_producer, Stmt body);

    /** Convenience constructors for is_producer == true / false
     * respectively. */
    static Stmt make_produce(const std::string &name, Stmt body);
    static Stmt make_consume(const std::string &name, Stmt body);

    static const IRNodeType _node_type = IRNodeType::ProducerConsumer;
};
12201
12202/** Store a 'value' to the buffer called 'name' at a given 'index' if
12203 * 'predicate' is true. The buffer is interpreted as an array of the
12204 * same type as 'value'. The name may be the name of an enclosing
12205 * Allocate node, an output buffer, or any other symbol of type
12206 * Handle(). */
struct Store : public StmtNode<Store> {
    std::string name;
    Expr predicate, value, index;  // Store only happens where 'predicate' is true
    // If it's a store to an output buffer, then this parameter points to it.
    Parameter param;

    // The alignment of the index. If the index is a vector, this is
    // the alignment of the first lane.
    ModulusRemainder alignment;

    // Note: the argument order here differs from the member order above.
    static Stmt make(const std::string &name, Expr value, Expr index,
                     Parameter param, Expr predicate, ModulusRemainder alignment);

    static const IRNodeType _node_type = IRNodeType::Store;
};
12222
12223/** This defines the value of a function at a multi-dimensional
12224 * location. You should think of it as a store to a multi-dimensional
12225 * array. It gets lowered to a conventional Store node. The name must
12226 * correspond to an output buffer or the name of an enclosing Realize
12227 * node. */
struct Provide : public StmtNode<Provide> {
    std::string name;
    std::vector<Expr> values;  // The values stored (one per value of a multi-valued function)
    std::vector<Expr> args;    // The multi-dimensional coordinates of the store

    static Stmt make(const std::string &name, const std::vector<Expr> &values, const std::vector<Expr> &args);

    static const IRNodeType _node_type = IRNodeType::Provide;
};
12237
12238/** Allocate a scratch area called with the given name, type, and
12239 * size. The buffer lives for at most the duration of the body
12240 * statement, within which it may or may not be freed explicitly with
12241 * a Free node with a matching name. Allocation only occurs if the
12242 * condition evaluates to true. Within the body of the allocation,
12243 * defines a symbol with the given name and the type Handle(). */
struct Allocate : public StmtNode<Allocate> {
    std::string name;
    Type type;
    MemoryType memory_type;
    std::vector<Expr> extents;  // One extent per dimension
    Expr condition;             // Allocation only occurs if this evaluates to true

    // These override the code generator dependent malloc and free
    // equivalents if provided. If the new_expr succeeds (that is, it
    // returns non-nullptr), the function named by free_function is
    // guaranteed to be called. The free function signature must match
    // that of the code generator dependent free (typically
    // halide_free). If free_function is left empty, code generator
    // default will be called.
    Expr new_expr;
    std::string free_function;

    Stmt body;

    static Stmt make(const std::string &name, Type type, MemoryType memory_type,
                     const std::vector<Expr> &extents,
                     Expr condition, Stmt body,
                     Expr new_expr = Expr(), const std::string &free_function = std::string());

    /** A routine to check if the extents are all constants, and if so verify
     * the total size is less than 2^31 - 1. If the result is constant, but
     * overflows, this routine asserts. This returns 0 if the extents are
     * not all constants; otherwise, it returns the total constant allocation
     * size. */
    static int32_t constant_allocation_size(const std::vector<Expr> &extents, const std::string &name);
    int32_t constant_allocation_size() const;

    static const IRNodeType _node_type = IRNodeType::Allocate;
};
12278
12279/** Free the resources associated with the given buffer. */
struct Free : public StmtNode<Free> {
    std::string name;  // Name of the matching Allocate node

    static Stmt make(const std::string &name);

    static const IRNodeType _node_type = IRNodeType::Free;
};
12287
12288/** Allocate a multi-dimensional buffer of the given type and
12289 * size. Create some scratch memory that will back the function 'name'
12290 * over the range specified in 'bounds'. The bounds are a vector of
12291 * (min, extent) pairs for each dimension. Allocation only occurs if
12292 * the condition evaluates to true.
12293 */
struct Realize : public StmtNode<Realize> {
    std::string name;
    std::vector<Type> types;  // One type per value of a multi-valued function
    MemoryType memory_type;
    Region bounds;            // (min, extent) pairs, one per dimension
    Expr condition;           // Allocation only occurs if this evaluates to true
    Stmt body;

    static Stmt make(const std::string &name, const std::vector<Type> &types, MemoryType memory_type, const Region &bounds, Expr condition, Stmt body);

    static const IRNodeType _node_type = IRNodeType::Realize;
};
12306
12307/** A sequence of statements to be executed in-order. 'rest' may be
 * undefined. Use rest.defined() to find out. */
struct Block : public StmtNode<Block> {
    Stmt first, rest;  // 'rest' may be undefined

    static Stmt make(Stmt first, Stmt rest);
    /** Construct zero or more Blocks to invoke a list of statements in order.
     * This method may not return a Block statement if stmts.size() <= 1. */
    static Stmt make(const std::vector<Stmt> &stmts);

    static const IRNodeType _node_type = IRNodeType::Block;
};
12319
12320/** A pair of statements executed concurrently. Both statements are
12321 * joined before the Stmt ends. This is the parallel equivalent to
12322 * Block. */
struct Fork : public StmtNode<Fork> {
    Stmt first, rest;  // Executed concurrently; both join before the Fork ends

    static Stmt make(Stmt first, Stmt rest);

    static const IRNodeType _node_type = IRNodeType::Fork;
};
12330
12331/** An if-then-else block. 'else' may be undefined. */
struct IfThenElse : public StmtNode<IfThenElse> {
    Expr condition;
    Stmt then_case, else_case;  // 'else_case' may be undefined

    static Stmt make(Expr condition, Stmt then_case, Stmt else_case = Stmt());

    static const IRNodeType _node_type = IRNodeType::IfThenElse;
};
12340
12341/** Evaluate and discard an expression, presumably because it has some side-effect. */
struct Evaluate : public StmtNode<Evaluate> {
    Expr value;  // Evaluated for its side effects; the result is discarded

    static Stmt make(Expr v);

    static const IRNodeType _node_type = IRNodeType::Evaluate;
};
12349
12350/** A function call. This can represent a call to some extern function
12351 * (like sin), but it's also our multi-dimensional version of a Load,
12352 * so it can be a load from an input image, or a call to another
12353 * halide function. These two types of call nodes don't survive all
12354 * the way down to code generation - the lowering process converts
12355 * them to Load nodes. */
12356struct Call : public ExprNode<Call> {
12357 std::string name;
12358 std::vector<Expr> args;
12359 typedef enum { Image, ///< A load from an input image
12360 Extern, ///< A call to an external C-ABI function, possibly with side-effects
12361 ExternCPlusPlus, ///< A call to an external C-ABI function, possibly with side-effects
12362 PureExtern, ///< A call to a guaranteed-side-effect-free external function
12363 Halide, ///< A call to a Func
12364 Intrinsic, ///< A possibly-side-effecty compiler intrinsic, which has special handling during codegen
12365 PureIntrinsic ///< A side-effect-free version of the above.
12366 } CallType;
12367 CallType call_type;
12368
12369 // Halide uses calls internally to represent certain operations
12370 // (instead of IR nodes). These are matched by name. Note that
12371 // these are deliberately char* (rather than std::string) so that
12372 // they can be referenced at static-initialization time without
12373 // risking ambiguous initalization order; we use a typedef to simplify
12374 // declaration.
12375 typedef const char *const ConstString;
12376
12377 // enums for various well-known intrinsics. (It is not *required* that all
12378 // intrinsics have an enum entry here, but as a matter of style, it is recommended.)
12379 // Note that these are only used in the API; inside the node, they are translated
12380 // into a name. (To recover the name, call get_intrinsic_name().)
12381 //
12382 // Please keep this list sorted alphabetically; the specific enum values
12383 // are *not* guaranteed to be stable across time.
12384 enum IntrinsicOp {
12385 abs,
12386 absd,
12387 add_image_checks_marker,
12388 alloca,
12389 bitwise_and,
12390 bitwise_not,
12391 bitwise_or,
12392 bitwise_xor,
12393 bool_to_mask,
12394 bundle, // Bundle multiple exprs together temporarily for analysis (e.g. CSE)
12395 call_cached_indirect_function,
12396 cast_mask,
12397 count_leading_zeros,
12398 count_trailing_zeros,
12399 declare_box_touched,
12400 debug_to_file,
12401 div_round_to_zero,
12402 dynamic_shuffle,
12403 extract_mask_element,
12404 gpu_thread_barrier,
12405 halving_add,
12406 halving_sub,
12407 hvx_gather,
12408 hvx_scatter,
12409 hvx_scatter_acc,
12410 hvx_scatter_release,
12411 if_then_else,
12412 if_then_else_mask,
12413 image_load,
12414 image_store,
12415 lerp,
12416 likely,
12417 likely_if_innermost,
12418 make_struct,
12419 memoize_expr,
12420 mod_round_to_zero,
12421 mul_shift_right,
12422 mux,
12423 popcount,
12424 predicate,
12425 prefetch,
12426 promise_clamped,
12427 random,
12428 register_destructor,
12429 reinterpret,
12430 require,
12431 require_mask,
12432 return_second,
12433 rewrite_buffer,
12434 rounding_halving_add,
12435 rounding_halving_sub,
12436 rounding_mul_shift_right,
12437 rounding_shift_left,
12438 rounding_shift_right,
12439 saturating_add,
12440 saturating_sub,
12441 scatter_gather,
12442 select_mask,
12443 shift_left,
12444 shift_right,
12445 signed_integer_overflow,
12446 size_of_halide_buffer_t,
12447 sorted_avg, // Compute (arg[0] + arg[1]) / 2, assuming arg[0] < arg[1].
12448 strict_float,
12449 stringify,
12450 undef,
12451 unreachable,
12452 unsafe_promise_clamped,
12453 widening_add,
12454 widening_mul,
12455 widening_shift_left,
12456 widening_shift_right,
12457 widening_sub,
12458 IntrinsicOpCount // Sentinel: keep last.
12459 };
12460
12461 static const char *get_intrinsic_name(IntrinsicOp op);
12462
12463 // We also declare some symbolic names for some of the runtime
12464 // functions that we want to construct Call nodes to here to avoid
12465 // magic string constants and the potential risk of typos.
12466 HALIDE_EXPORT static ConstString
12467 buffer_get_dimensions,
12468 buffer_get_min,
12469 buffer_get_extent,
12470 buffer_get_stride,
12471 buffer_get_max,
12472 buffer_get_host,
12473 buffer_get_device,
12474 buffer_get_device_interface,
12475 buffer_get_shape,
12476 buffer_get_host_dirty,
12477 buffer_get_device_dirty,
12478 buffer_get_type,
12479 buffer_set_host_dirty,
12480 buffer_set_device_dirty,
12481 buffer_is_bounds_query,
12482 buffer_init,
12483 buffer_init_from_buffer,
12484 buffer_crop,
12485 buffer_set_bounds,
12486 trace;
12487
12488 // If it's a call to another halide function, this call node holds
12489 // a possibly-weak reference to that function.
12490 FunctionPtr func;
12491
12492 // If that function has multiple values, which value does this
12493 // call node refer to?
12494 int value_index;
12495
12496 // If it's a call to an image, this call nodes hold a
12497 // pointer to that image's buffer
12498 Buffer<> image;
12499
12500 // If it's a call to an image parameter, this call node holds a
12501 // pointer to that
12502 Parameter param;
12503
12504 static Expr make(Type type, IntrinsicOp op, const std::vector<Expr> &args, CallType call_type,
12505 FunctionPtr func = FunctionPtr(), int value_index = 0,
12506 const Buffer<> &image = Buffer<>(), Parameter param = Parameter());
12507
12508 static Expr make(Type type, const std::string &name, const std::vector<Expr> &args, CallType call_type,
12509 FunctionPtr func = FunctionPtr(), int value_index = 0,
12510 Buffer<> image = Buffer<>(), Parameter param = Parameter());
12511
12512 /** Convenience constructor for calls to other halide functions */
12513 static Expr make(const Function &func, const std::vector<Expr> &args, int idx = 0);
12514
12515 /** Convenience constructor for loads from concrete images */
12516 static Expr make(const Buffer<> &image, const std::vector<Expr> &args) {
12517 return make(image.type(), image.name(), args, Image, FunctionPtr(), 0, image, Parameter());
12518 }
12519
12520 /** Convenience constructor for loads from images parameters */
12521 static Expr make(const Parameter &param, const std::vector<Expr> &args) {
12522 return make(param.type(), param.name(), args, Image, FunctionPtr(), 0, Buffer<>(), param);
12523 }
12524
12525 /** Check if a call node is pure within a pipeline, meaning that
12526 * the same args always give the same result, and the calls can be
12527 * reordered, duplicated, unified, etc without changing the
12528 * meaning of anything. Not transitive - doesn't guarantee the
12529 * args themselves are pure. An example of a pure Call node is
12530 * sqrt. If in doubt, don't mark a Call node as pure. */
12531 bool is_pure() const {
12532 return (call_type == PureExtern ||
12533 call_type == Image ||
12534 call_type == PureIntrinsic);
12535 }
12536
12537 bool is_intrinsic() const {
12538 return (call_type == Intrinsic ||
12539 call_type == PureIntrinsic);
12540 }
12541
12542 bool is_intrinsic(IntrinsicOp op) const {
12543 return is_intrinsic() && this->name == get_intrinsic_name(op);
12544 }
12545
12546 /** Returns a pointer to a call node if the expression is a call to
12547 * one of the requested intrinsics. */
12548 static const Call *as_intrinsic(const Expr &e, std::initializer_list<IntrinsicOp> intrinsics) {
12549 if (const Call *c = e.as<Call>()) {
12550 for (IntrinsicOp i : intrinsics) {
12551 if (c->is_intrinsic(i)) {
12552 return c;
12553 }
12554 }
12555 }
12556 return nullptr;
12557 }
12558
    /** If 'e' is a call to one of the "tag" intrinsics listed below,
     * return the Call node; otherwise return nullptr. */
    static const Call *as_tag(const Expr &e) {
        return as_intrinsic(e, {Call::likely, Call::likely_if_innermost, Call::predicate, Call::strict_float});
    }
12562
12563 bool is_extern() const {
12564 return (call_type == Extern ||
12565 call_type == ExternCPlusPlus ||
12566 call_type == PureExtern);
12567 }
12568
12569 static const IRNodeType _node_type = IRNodeType::Call;
12570};
12571
12572/** A named variable. Might be a loop variable, function argument,
12573 * parameter, reduction variable, or something defined by a Let or
12574 * LetStmt node. */
struct Variable : public ExprNode<Variable> {
    std::string name;

    /** References to scalar parameters, or to the dimensions of buffer
     * parameters hang onto those expressions. */
    Parameter param;

    /** References to properties of literal image parameters. */
    Buffer<> image;

    /** Reduction variables hang onto their domains */
    ReductionDomain reduction_domain;

    /** Make a plain variable with no parameter, image, or reduction
     * domain attached. */
    static Expr make(Type type, const std::string &name) {
        return make(type, name, Buffer<>(), Parameter(), ReductionDomain());
    }

    /** Make a variable that hangs onto a (scalar or buffer) parameter. */
    static Expr make(Type type, const std::string &name, Parameter param) {
        return make(type, name, Buffer<>(), std::move(param), ReductionDomain());
    }

    /** Make a variable that hangs onto a literal image. */
    static Expr make(Type type, const std::string &name, const Buffer<> &image) {
        return make(type, name, image, Parameter(), ReductionDomain());
    }

    /** Make a reduction variable attached to the given domain. */
    static Expr make(Type type, const std::string &name, ReductionDomain reduction_domain) {
        return make(type, name, Buffer<>(), Parameter(), std::move(reduction_domain));
    }

    /** The general form: all of the convenience overloads above
     * forward to this out-of-line constructor. */
    static Expr make(Type type, const std::string &name, Buffer<> image,
                     Parameter param, ReductionDomain reduction_domain);

    static const IRNodeType _node_type = IRNodeType::Variable;
};
12609
12610/** A for loop. Execute the 'body' statement for all values of the
12611 * variable 'name' from 'min' to 'min + extent'. There are four
12612 * types of For nodes. A 'Serial' for loop is a conventional
12613 * one. In a 'Parallel' for loop, each iteration of the loop
12614 * happens in parallel or in some unspecified order. In a
12615 * 'Vectorized' for loop, each iteration maps to one SIMD lane,
12616 * and the whole loop is executed in one shot. For this case,
12617 * 'extent' must be some small integer constant (probably 4, 8, or
12618 * 16). An 'Unrolled' for loop compiles to a completely unrolled
12619 * version of the loop. Each iteration becomes its own
12620 * statement. Again in this case, 'extent' should be a small
12621 * integer constant. */
struct For : public StmtNode<For> {
    std::string name;
    // Loop bounds: iterates from 'min' up to (but not including) 'min + extent'.
    Expr min, extent;
    ForType for_type;
    // Device API associated with this loop (see the DeviceAPI enum).
    DeviceAPI device_api;
    Stmt body;

    static Stmt make(const std::string &name, Expr min, Expr extent, ForType for_type, DeviceAPI device_api, Stmt body);

    /** True if iterations of this loop may run concurrently or in an
     * arbitrary order (delegates to the free function of the same name). */
    bool is_unordered_parallel() const {
        return Halide::Internal::is_unordered_parallel(for_type);
    }
    /** True if this is any flavor of parallel loop (delegates to the
     * free function of the same name). */
    bool is_parallel() const {
        return Halide::Internal::is_parallel(for_type);
    }

    static const IRNodeType _node_type = IRNodeType::For;
};
12640
struct Acquire : public StmtNode<Acquire> {
    // The semaphore being acquired.
    Expr semaphore;
    // NOTE(review): presumably the number of units of the semaphore to
    // acquire before executing 'body' — confirm against the lowering
    // pass that consumes these nodes.
    Expr count;
    Stmt body;

    static Stmt make(Expr semaphore, Expr count, Stmt body);

    static const IRNodeType _node_type = IRNodeType::Acquire;
};
12650
12651/** Construct a new vector by taking elements from another sequence of
12652 * vectors. */
struct Shuffle : public ExprNode<Shuffle> {
    /** The vectors to draw elements from. */
    std::vector<Expr> vectors;

    /** Indices indicating which vector element to place into the
     * result. The elements are numbered by their position in the
     * concatenation of the vector arguments. */
    std::vector<int> indices;

    static Expr make(const std::vector<Expr> &vectors,
                     const std::vector<int> &indices);

    /** Convenience constructor for making a shuffle representing an
     * interleaving of vectors of the same length. */
    static Expr make_interleave(const std::vector<Expr> &vectors);

    /** Convenience constructor for making a shuffle representing a
     * concatenation of the vectors. */
    static Expr make_concat(const std::vector<Expr> &vectors);

    /** Convenience constructor for making a shuffle representing a
     * broadcast of a vector. */
    static Expr make_broadcast(Expr vector, int factor);

    /** Convenience constructor for making a shuffle representing a
     * contiguous subset of a vector. */
    static Expr make_slice(Expr vector, int begin, int stride, int size);

    /** Convenience constructor for making a shuffle representing
     * extracting a single element. */
    static Expr make_extract_element(Expr vector, int i);

    /** Check if this shuffle is an interleaving of the vector
     * arguments. */
    bool is_interleave() const;

    /** Check if this shuffle can be represented as a broadcast.
     * For example:
     * A uint8 shuffle with 4*n lanes and indices:
     * 0, 1, 2, 3, 0, 1, 2, 3, ....., 0, 1, 2, 3
     * can be represented as a uint32 broadcast with n lanes (factor = 4). */
    bool is_broadcast() const;
    int broadcast_factor() const;

    /** Check if this shuffle is a concatenation of the vector
     * arguments. */
    bool is_concat() const;

    /** Check if this shuffle is a contiguous strict subset of the
     * vector arguments, and if so, the offset and stride of the
     * slice. */
    ///@{
    bool is_slice() const;
    /** The first index selected. Only meaningful if is_slice() is true. */
    int slice_begin() const {
        return indices[0];
    }
    /** The step between consecutive selected indices (1 when fewer than
     * two indices exist). Only meaningful if is_slice() is true. */
    int slice_stride() const {
        return indices.size() >= 2 ? indices[1] - indices[0] : 1;
    }
    ///@}

    /** Check if this shuffle is extracting a scalar from the vector
     * arguments. */
    bool is_extract_element() const;

    static const IRNodeType _node_type = IRNodeType::Shuffle;
};
12719
12720/** Represent a multi-dimensional region of a Func or an ImageParam that
12721 * needs to be prefetched. */
struct Prefetch : public StmtNode<Prefetch> {
    std::string name;
    // Types of the values in the prefetched region.
    std::vector<Type> types;
    // The multi-dimensional region to prefetch.
    Region bounds;
    // The scheduling directive that requested this prefetch.
    PrefetchDirective prefetch;
    // NOTE(review): presumably the prefetch is only issued when this
    // condition holds — confirm against the lowering pass.
    Expr condition;

    Stmt body;

    static Stmt make(const std::string &name, const std::vector<Type> &types,
                     const Region &bounds,
                     const PrefetchDirective &prefetch,
                     Expr condition, Stmt body);

    static const IRNodeType _node_type = IRNodeType::Prefetch;
};
12738
12739/** Lock all the Store nodes in the body statement.
12740 * Typically the lock is implemented by an atomic operation
12741 * (e.g. atomic add or atomic compare-and-swap).
12742 * However, if necessary, the node can access a mutex buffer through
12743 * mutex_name and mutex_args, by lowering this node into
12744 * calls to acquire and release the lock. */
12745struct Atomic : public StmtNode<Atomic> {
12746 std::string producer_name;
12747 std::string mutex_name; // empty string if not using mutex
12748 Stmt body;
12749
12750 static Stmt make(const std::string &producer_name,
12751 const std::string &mutex_name,
12752 Stmt body);
12753
12754 static const IRNodeType _node_type = IRNodeType::Atomic;
12755};
12756
12757/** Horizontally reduce a vector to a scalar or narrower vector using
12758 * the given commutative and associative binary operator. The reduction
12759 * factor is dictated by the number of lanes in the input and output
 * types. Groups of adjacent lanes are combined. The number of lanes
 * in the output type must be a divisor of the number of lanes of the
 * input type. */
struct VectorReduce : public ExprNode<VectorReduce> {
    // 99.9% of the time people will use this for horizontal addition,
    // but these are all of our commutative and associative primitive
    // operators.
    typedef enum {
        Add,
        SaturatingAdd,
        Mul,
        Min,
        Max,
        And,
        Or,
    } Operator;

    // The vector expression being reduced.
    Expr value;
    // The binary operator used to combine groups of adjacent lanes.
    Operator op;

    /** Construct a horizontal reduction of 'vec' down to 'lanes' lanes
     * using 'op'. */
    static Expr make(Operator op, Expr vec, int lanes);

    static const IRNodeType _node_type = IRNodeType::VectorReduce;
};
12784
12785} // namespace Internal
12786} // namespace Halide
12787
12788#endif
12789
12790/** \file
12791 * Defines the base class for things that recursively walk over the IR
12792 */
12793
12794namespace Halide {
12795namespace Internal {
12796
12797/** A base class for algorithms that need to recursively walk over the
12798 * IR. The default implementations just recursively walk over the
12799 * children. Override the ones you care about.
12800 */
class IRVisitor {
public:
    IRVisitor() = default;
    virtual ~IRVisitor() = default;

protected:
    // ExprNode<> and StmtNode<> are allowed to call visit (to implement accept())
    template<typename T>
    friend struct ExprNode;

    template<typename T>
    friend struct StmtNode;

    /** One visit method per concrete IR node type. The default
     * implementations (defined out of line) recursively visit the
     * node's children; override the ones you care about. */
    // @{
    virtual void visit(const IntImm *);
    virtual void visit(const UIntImm *);
    virtual void visit(const FloatImm *);
    virtual void visit(const StringImm *);
    virtual void visit(const Cast *);
    virtual void visit(const Variable *);
    virtual void visit(const Add *);
    virtual void visit(const Sub *);
    virtual void visit(const Mul *);
    virtual void visit(const Div *);
    virtual void visit(const Mod *);
    virtual void visit(const Min *);
    virtual void visit(const Max *);
    virtual void visit(const EQ *);
    virtual void visit(const NE *);
    virtual void visit(const LT *);
    virtual void visit(const LE *);
    virtual void visit(const GT *);
    virtual void visit(const GE *);
    virtual void visit(const And *);
    virtual void visit(const Or *);
    virtual void visit(const Not *);
    virtual void visit(const Select *);
    virtual void visit(const Load *);
    virtual void visit(const Ramp *);
    virtual void visit(const Broadcast *);
    virtual void visit(const Call *);
    virtual void visit(const Let *);
    virtual void visit(const LetStmt *);
    virtual void visit(const AssertStmt *);
    virtual void visit(const ProducerConsumer *);
    virtual void visit(const For *);
    virtual void visit(const Store *);
    virtual void visit(const Provide *);
    virtual void visit(const Allocate *);
    virtual void visit(const Free *);
    virtual void visit(const Realize *);
    virtual void visit(const Block *);
    virtual void visit(const IfThenElse *);
    virtual void visit(const Evaluate *);
    virtual void visit(const Shuffle *);
    virtual void visit(const VectorReduce *);
    virtual void visit(const Prefetch *);
    virtual void visit(const Fork *);
    virtual void visit(const Acquire *);
    virtual void visit(const Atomic *);
    // @}
};
12861
12862/** A base class for algorithms that walk recursively over the IR
12863 * without visiting the same node twice. This is for passes that are
12864 * capable of interpreting the IR as a DAG instead of a tree. */
class IRGraphVisitor : public IRVisitor {
protected:
    /** By default these methods add the node to the visited set, and
     * return whether or not it was already there. If it wasn't there,
     * it delegates to the appropriate visit method. You can override
     * them if you like. */
    // @{
    virtual void include(const Expr &);
    virtual void include(const Stmt &);
    // @}

private:
    /** The nodes visited so far. Keyed by IRHandle, so two Exprs that
     * share the same underlying node count as one visit. */
    std::set<IRHandle> visited;

protected:
    /** These methods should call 'include' on the children to only
     * visit them if they haven't been visited already. */
    // @{
    void visit(const IntImm *) override;
    void visit(const UIntImm *) override;
    void visit(const FloatImm *) override;
    void visit(const StringImm *) override;
    void visit(const Cast *) override;
    void visit(const Variable *) override;
    void visit(const Add *) override;
    void visit(const Sub *) override;
    void visit(const Mul *) override;
    void visit(const Div *) override;
    void visit(const Mod *) override;
    void visit(const Min *) override;
    void visit(const Max *) override;
    void visit(const EQ *) override;
    void visit(const NE *) override;
    void visit(const LT *) override;
    void visit(const LE *) override;
    void visit(const GT *) override;
    void visit(const GE *) override;
    void visit(const And *) override;
    void visit(const Or *) override;
    void visit(const Not *) override;
    void visit(const Select *) override;
    void visit(const Load *) override;
    void visit(const Ramp *) override;
    void visit(const Broadcast *) override;
    void visit(const Call *) override;
    void visit(const Let *) override;
    void visit(const LetStmt *) override;
    void visit(const AssertStmt *) override;
    void visit(const ProducerConsumer *) override;
    void visit(const For *) override;
    void visit(const Store *) override;
    void visit(const Provide *) override;
    void visit(const Allocate *) override;
    void visit(const Free *) override;
    void visit(const Realize *) override;
    void visit(const Block *) override;
    void visit(const IfThenElse *) override;
    void visit(const Evaluate *) override;
    void visit(const Shuffle *) override;
    void visit(const VectorReduce *) override;
    void visit(const Prefetch *) override;
    void visit(const Acquire *) override;
    void visit(const Fork *) override;
    void visit(const Atomic *) override;
    // @}
};
12932
12933/** A visitor/mutator capable of passing arbitrary arguments to the
12934 * visit methods using CRTP and returning any types from them. All
12935 * Expr visitors must have the same signature, and all Stmt visitors
12936 * must have the same signature. Does not have default implementations
12937 * of the visit methods. */
12938template<typename T, typename ExprRet, typename StmtRet>
12939class VariadicVisitor {
12940private:
12941 template<typename... Args>
12942 ExprRet dispatch_expr(const BaseExprNode *node, Args &&...args) {
12943 if (node == nullptr) {
12944 return ExprRet{};
12945 }
12946 switch (node->node_type) {
12947 case IRNodeType::IntImm:
12948 return ((T *)this)->visit((const IntImm *)node, std::forward<Args>(args)...);
12949 case IRNodeType::UIntImm:
12950 return ((T *)this)->visit((const UIntImm *)node, std::forward<Args>(args)...);
12951 case IRNodeType::FloatImm:
12952 return ((T *)this)->visit((const FloatImm *)node, std::forward<Args>(args)...);
12953 case IRNodeType::StringImm:
12954 return ((T *)this)->visit((const StringImm *)node, std::forward<Args>(args)...);
12955 case IRNodeType::Broadcast:
12956 return ((T *)this)->visit((const Broadcast *)node, std::forward<Args>(args)...);
12957 case IRNodeType::Cast:
12958 return ((T *)this)->visit((const Cast *)node, std::forward<Args>(args)...);
12959 case IRNodeType::Variable:
12960 return ((T *)this)->visit((const Variable *)node, std::forward<Args>(args)...);
12961 case IRNodeType::Add:
12962 return ((T *)this)->visit((const Add *)node, std::forward<Args>(args)...);
12963 case IRNodeType::Sub:
12964 return ((T *)this)->visit((const Sub *)node, std::forward<Args>(args)...);
12965 case IRNodeType::Mod:
12966 return ((T *)this)->visit((const Mod *)node, std::forward<Args>(args)...);
12967 case IRNodeType::Mul:
12968 return ((T *)this)->visit((const Mul *)node, std::forward<Args>(args)...);
12969 case IRNodeType::Div:
12970 return ((T *)this)->visit((const Div *)node, std::forward<Args>(args)...);
12971 case IRNodeType::Min:
12972 return ((T *)this)->visit((const Min *)node, std::forward<Args>(args)...);
12973 case IRNodeType::Max:
12974 return ((T *)this)->visit((const Max *)node, std::forward<Args>(args)...);
12975 case IRNodeType::EQ:
12976 return ((T *)this)->visit((const EQ *)node, std::forward<Args>(args)...);
12977 case IRNodeType::NE:
12978 return ((T *)this)->visit((const NE *)node, std::forward<Args>(args)...);
12979 case IRNodeType::LT:
12980 return ((T *)this)->visit((const LT *)node, std::forward<Args>(args)...);
12981 case IRNodeType::LE:
12982 return ((T *)this)->visit((const LE *)node, std::forward<Args>(args)...);
12983 case IRNodeType::GT:
12984 return ((T *)this)->visit((const GT *)node, std::forward<Args>(args)...);
12985 case IRNodeType::GE:
12986 return ((T *)this)->visit((const GE *)node, std::forward<Args>(args)...);
12987 case IRNodeType::And:
12988 return ((T *)this)->visit((const And *)node, std::forward<Args>(args)...);
12989 case IRNodeType::Or:
12990 return ((T *)this)->visit((const Or *)node, std::forward<Args>(args)...);
12991 case IRNodeType::Not:
12992 return ((T *)this)->visit((const Not *)node, std::forward<Args>(args)...);
12993 case IRNodeType::Select:
12994 return ((T *)this)->visit((const Select *)node, std::forward<Args>(args)...);
12995 case IRNodeType::Load:
12996 return ((T *)this)->visit((const Load *)node, std::forward<Args>(args)...);
12997 case IRNodeType::Ramp:
12998 return ((T *)this)->visit((const Ramp *)node, std::forward<Args>(args)...);
12999 case IRNodeType::Call:
13000 return ((T *)this)->visit((const Call *)node, std::forward<Args>(args)...);
13001 case IRNodeType::Let:
13002 return ((T *)this)->visit((const Let *)node, std::forward<Args>(args)...);
13003 case IRNodeType::Shuffle:
13004 return ((T *)this)->visit((const Shuffle *)node, std::forward<Args>(args)...);
13005 case IRNodeType::VectorReduce:
13006 return ((T *)this)->visit((const VectorReduce *)node, std::forward<Args>(args)...);
13007 // Explicitly list the Stmt types rather than using a
13008 // default case so that when new IR nodes are added we
13009 // don't miss them here.
13010 case IRNodeType::LetStmt:
13011 case IRNodeType::AssertStmt:
13012 case IRNodeType::ProducerConsumer:
13013 case IRNodeType::For:
13014 case IRNodeType::Acquire:
13015 case IRNodeType::Store:
13016 case IRNodeType::Provide:
13017 case IRNodeType::Allocate:
13018 case IRNodeType::Free:
13019 case IRNodeType::Realize:
13020 case IRNodeType::Block:
13021 case IRNodeType::Fork:
13022 case IRNodeType::IfThenElse:
13023 case IRNodeType::Evaluate:
13024 case IRNodeType::Prefetch:
13025 case IRNodeType::Atomic:
13026 internal_error << "Unreachable";
13027 }
13028 return ExprRet{};
13029 }
13030
13031 template<typename... Args>
13032 StmtRet dispatch_stmt(const BaseStmtNode *node, Args &&...args) {
13033 if (node == nullptr) {
13034 return StmtRet{};
13035 }
13036 switch (node->node_type) {
13037 case IRNodeType::IntImm:
13038 case IRNodeType::UIntImm:
13039 case IRNodeType::FloatImm:
13040 case IRNodeType::StringImm:
13041 case IRNodeType::Broadcast:
13042 case IRNodeType::Cast:
13043 case IRNodeType::Variable:
13044 case IRNodeType::Add:
13045 case IRNodeType::Sub:
13046 case IRNodeType::Mod:
13047 case IRNodeType::Mul:
13048 case IRNodeType::Div:
13049 case IRNodeType::Min:
13050 case IRNodeType::Max:
13051 case IRNodeType::EQ:
13052 case IRNodeType::NE:
13053 case IRNodeType::LT:
13054 case IRNodeType::LE:
13055 case IRNodeType::GT:
13056 case IRNodeType::GE:
13057 case IRNodeType::And:
13058 case IRNodeType::Or:
13059 case IRNodeType::Not:
13060 case IRNodeType::Select:
13061 case IRNodeType::Load:
13062 case IRNodeType::Ramp:
13063 case IRNodeType::Call:
13064 case IRNodeType::Let:
13065 case IRNodeType::Shuffle:
13066 case IRNodeType::VectorReduce:
13067 internal_error << "Unreachable";
13068 break;
13069 case IRNodeType::LetStmt:
13070 return ((T *)this)->visit((const LetStmt *)node, std::forward<Args>(args)...);
13071 case IRNodeType::AssertStmt:
13072 return ((T *)this)->visit((const AssertStmt *)node, std::forward<Args>(args)...);
13073 case IRNodeType::ProducerConsumer:
13074 return ((T *)this)->visit((const ProducerConsumer *)node, std::forward<Args>(args)...);
13075 case IRNodeType::For:
13076 return ((T *)this)->visit((const For *)node, std::forward<Args>(args)...);
13077 case IRNodeType::Acquire:
13078 return ((T *)this)->visit((const Acquire *)node, std::forward<Args>(args)...);
13079 case IRNodeType::Store:
13080 return ((T *)this)->visit((const Store *)node, std::forward<Args>(args)...);
13081 case IRNodeType::Provide:
13082 return ((T *)this)->visit((const Provide *)node, std::forward<Args>(args)...);
13083 case IRNodeType::Allocate:
13084 return ((T *)this)->visit((const Allocate *)node, std::forward<Args>(args)...);
13085 case IRNodeType::Free:
13086 return ((T *)this)->visit((const Free *)node, std::forward<Args>(args)...);
13087 case IRNodeType::Realize:
13088 return ((T *)this)->visit((const Realize *)node, std::forward<Args>(args)...);
13089 case IRNodeType::Block:
13090 return ((T *)this)->visit((const Block *)node, std::forward<Args>(args)...);
13091 case IRNodeType::Fork:
13092 return ((T *)this)->visit((const Fork *)node, std::forward<Args>(args)...);
13093 case IRNodeType::IfThenElse:
13094 return ((T *)this)->visit((const IfThenElse *)node, std::forward<Args>(args)...);
13095 case IRNodeType::Evaluate:
13096 return ((T *)this)->visit((const Evaluate *)node, std::forward<Args>(args)...);
13097 case IRNodeType::Prefetch:
13098 return ((T *)this)->visit((const Prefetch *)node, std::forward<Args>(args)...);
13099 case IRNodeType::Atomic:
13100 return ((T *)this)->visit((const Atomic *)node, std::forward<Args>(args)...);
13101 }
13102 return StmtRet{};
13103 }
13104
13105public:
13106 template<typename... Args>
13107 HALIDE_ALWAYS_INLINE StmtRet dispatch(const Stmt &s, Args &&...args) {
13108 return dispatch_stmt(s.get(), std::forward<Args>(args)...);
13109 }
13110
13111 template<typename... Args>
13112 HALIDE_ALWAYS_INLINE StmtRet dispatch(Stmt &&s, Args &&...args) {
13113 return dispatch_stmt(s.get(), std::forward<Args>(args)...);
13114 }
13115
13116 template<typename... Args>
13117 HALIDE_ALWAYS_INLINE ExprRet dispatch(const Expr &e, Args &&...args) {
13118 return dispatch_expr(e.get(), std::forward<Args>(args)...);
13119 }
13120
13121 template<typename... Args>
13122 HALIDE_ALWAYS_INLINE ExprRet dispatch(Expr &&e, Args &&...args) {
13123 return dispatch_expr(e.get(), std::forward<Args>(args)...);
13124 }
13125};
13126
13127} // namespace Internal
13128} // namespace Halide
13129
13130#endif
13131
13132namespace Halide {
13133namespace Internal {
13134
13135typedef std::map<std::string, Interval> DimBounds;
13136
13137const int64_t unknown = std::numeric_limits<int64_t>::min();
13138
13139/** Visitor for keeping track of functions that are directly called and the
13140 * arguments with which they are called. */
13141class FindAllCalls : public IRVisitor {
13142 using IRVisitor::visit;
13143
13144 void visit(const Call *call) override {
13145 if (call->call_type == Call::Halide || call->call_type == Call::Image) {
13146 funcs_called.insert(call->name);
13147 call_args.emplace_back(call->name, call->args);
13148 }
13149 for (size_t i = 0; i < call->args.size(); i++) {
13150 call->args[i].accept(this);
13151 }
13152 }
13153
13154public:
13155 std::set<std::string> funcs_called;
13156 std::vector<std::pair<std::string, std::vector<Expr>>> call_args;
13157};
13158
13159/** Return an int representation of 's'. Throw an error on failure. */
13160int string_to_int(const std::string &s);
13161
13162/** Substitute every variable in an Expr or a Stmt with its estimate
13163 * if specified. */
13164//@{
13165Expr substitute_var_estimates(Expr e);
13166Stmt substitute_var_estimates(Stmt s);
13167//@}
13168
13169/** Return the size of an interval. Return an undefined expr if the interval
13170 * is unbounded. */
13171Expr get_extent(const Interval &i);
13172
13173/** Return the size of an n-d box. */
13174Expr box_size(const Box &b);
13175
13176/** Helper function to print the bounds of a region. */
13177void disp_regions(const std::map<std::string, Box> &regions);
13178
13179/** Return the corresponding definition of a function given the stage. This
 * will throw an assertion if the function is an extern function (an extern
 * Func does not have a definition). */
13182Definition get_stage_definition(const Function &f, int stage_num);
13183
13184/** Return the corresponding loop dimensions of a function given the stage.
13185 * For extern Func, this will return a list of size 1 containing the
13186 * dummy __outermost loop dimension. */
13187std::vector<Dim> &get_stage_dims(const Function &f, int stage_num);
13188
13189/** Add partial load costs to the corresponding function in the result costs. */
13190void combine_load_costs(std::map<std::string, Expr> &result,
13191 const std::map<std::string, Expr> &partial);
13192
13193/** Return the required bounds of an intermediate stage (f, stage_num) of
13194 * function 'f' given the bounds of the pure dimensions. */
13195DimBounds get_stage_bounds(const Function &f, int stage_num, const DimBounds &pure_bounds);
13196
13197/** Return the required bounds for all the stages of the function 'f'. Each entry
13198 * in the returned vector corresponds to a stage. */
13199std::vector<DimBounds> get_stage_bounds(const Function &f, const DimBounds &pure_bounds);
13200
13201/** Recursively inline all the functions in the set 'inlines' into the
13202 * expression 'e' and return the resulting expression. If 'order' is
13203 * passed, inlining will be done in the reverse order of function realization
 * to avoid extra inlining work. */
13205Expr perform_inline(Expr e, const std::map<std::string, Function> &env,
13206 const std::set<std::string> &inlines = std::set<std::string>(),
13207 const std::vector<std::string> &order = std::vector<std::string>());
13208
13209/** Return all functions that are directly called by a function stage (f, stage). */
13210std::set<std::string> get_parents(Function f, int stage);
13211
13212/** Return value of element within a map. This will assert if the element is not
13213 * in the map. */
13214// @{
template<typename K, typename V>
V get_element(const std::map<K, V> &m, const K &key) {
    // Fail loudly if the key is absent rather than default-constructing.
    const auto found = m.find(key);
    internal_assert(found != m.end());
    return found->second;
}
13221
template<typename K, typename V>
V &get_element(std::map<K, V> &m, const K &key) {
    // Mutable overload: returns a reference into the map itself.
    const auto found = m.find(key);
    internal_assert(found != m.end());
    return found->second;
}
13228// @}
13229
13230/** If the cost of computing a Func is about the same as calling the Func,
 * inline the Func. Return true if any of the Funcs is inlined. */
13232bool inline_all_trivial_functions(const std::vector<Function> &outputs,
13233 const std::vector<std::string> &order,
13234 const std::map<std::string, Function> &env);
13235
13236/** Determine if a Func (order[index]) is only consumed by another single Func
13237 * in element-wise manner. If it is, return the name of the consumer Func;
13238 * otherwise, return an empty string. */
13239std::string is_func_called_element_wise(const std::vector<std::string> &order, size_t index,
13240 const std::map<std::string, Function> &env);
13241
13242/** Inline a Func if its values are only consumed by another single Func in
13243 * element-wise manner. */
13244bool inline_all_element_wise_functions(const std::vector<Function> &outputs,
13245 const std::vector<std::string> &order,
13246 const std::map<std::string, Function> &env);
13247
13248void propagate_estimate_test();
13249
13250} // namespace Internal
13251} // namespace Halide
13252
13253#endif
13254#ifndef HALIDE_BOUNDARY_CONDITIONS_H
13255#define HALIDE_BOUNDARY_CONDITIONS_H
13256
13257/** \file
13258 * Support for imposing boundary conditions on Halide::Funcs.
13259 */
13260
13261#include <vector>
13262
13263#ifndef HALIDE_FUNC_H
13264#define HALIDE_FUNC_H
13265
13266/** \file
13267 *
13268 * Defines Func - the front-end handle on a halide function, and related classes.
13269 */
13270
13271#ifndef HALIDE_JIT_MODULE_H
13272#define HALIDE_JIT_MODULE_H
13273
13274/** \file
13275 * Defines the struct representing lifetime and dependencies of
13276 * a JIT compiled halide pipeline
13277 */
13278
13279#include <map>
13280#include <memory>
13281
13282
13283namespace llvm {
13284class Module;
13285}
13286
13287namespace Halide {
13288
13289struct ExternCFunction;
13290struct JITExtern;
13291struct Target;
13292class Module;
13293
13294namespace Internal {
13295
13296class JITModuleContents;
13297struct LoweredFunc;
13298
13299struct JITModule {
13300 IntrusivePtr<JITModuleContents> jit_module;
13301
13302 struct Symbol {
13303 void *address = nullptr;
13304 Symbol() = default;
13305 explicit Symbol(void *address)
13306 : address(address) {
13307 }
13308 };
13309
13310 JITModule();
13311 JITModule(const Module &m, const LoweredFunc &fn,
13312 const std::vector<JITModule> &dependencies = std::vector<JITModule>());
13313
13314 /** Take a list of JITExterns and generate trampoline functions
13315 * which can be called dynamically via a function pointer that
13316 * takes an array of void *'s for each argument and the return
13317 * value.
13318 */
13319 static JITModule make_trampolines_module(const Target &target,
13320 const std::map<std::string, JITExtern> &externs,
13321 const std::string &suffix,
13322 const std::vector<JITModule> &deps);
13323
13324 /** The exports map of a JITModule contains all symbols which are
13325 * available to other JITModules which depend on this one. For
13326 * runtime modules, this is all of the symbols exported from the
13327 * runtime. For a JITted Func, it generally only contains the main
13328 * result Func of the compilation, which takes its name directly
13329 * from the Func declaration. One can also make a module which
13330 * contains no code itself but is just an exports maps providing
13331 * arbitrary pointers to functions or global variables to JITted
13332 * code. */
13333 const std::map<std::string, Symbol> &exports() const;
13334
13335 /** A pointer to the raw halide function. Its true type depends
13336 * on the Argument vector passed to CodeGen_LLVM::compile. Image
13337 * parameters become (halide_buffer_t *), and scalar parameters become
13338 * pointers to the appropriate values. The final argument is a
13339 * pointer to the halide_buffer_t defining the output. This will be nullptr for
13340 * a JITModule which has not yet been compiled or one that is not
13341 * a Halide Func compilation at all. */
13342 void *main_function() const;
13343
13344 /** Returns the Symbol structure for the routine documented in
13345 * main_function. Returning a Symbol allows access to the LLVM
13346 * type as well as the address. The address and type will be nullptr
13347 * if the module has not been compiled. */
13348 Symbol entrypoint_symbol() const;
13349
13350 /** Returns the Symbol structure for the argv wrapper routine
13351 * corresponding to the entrypoint. The argv wrapper is callable
13352 * via an array of void * pointers to the arguments for the
13353 * call. Returning a Symbol allows access to the LLVM type as well
13354 * as the address. The address and type will be nullptr if the module
13355 * has not been compiled. */
13356 Symbol argv_entrypoint_symbol() const;
13357
13358 /** A slightly more type-safe wrapper around the raw halide
 * module. Takes its arguments as an array of pointers that
13360 * correspond to the arguments to \ref main_function . This will
13361 * be nullptr for a JITModule which has not yet been compiled or one
13362 * that is not a Halide Func compilation at all. */
13363 // @{
13364 typedef int (*argv_wrapper)(const void **args);
13365 argv_wrapper argv_function() const;
13366 // @}
13367
13368 /** Add another JITModule to the dependency chain. Dependencies
13369 * are searched to resolve symbols not found in the current
13370 * compilation unit while JITting. */
13371 void add_dependency(JITModule &dep);
13372 /** Registers a single Symbol as available to modules which depend
13373 * on this one. The Symbol structure provides both the address and
13374 * the LLVM type for the function, which allows type safe linkage of
 * external routines. */
13376 void add_symbol_for_export(const std::string &name, const Symbol &extern_symbol);
13377 /** Registers a single function as available to modules which
13378 * depend on this one. This routine converts the ExternSignature
13379 * info into an LLVM type, which allows type safe linkage of
13380 * external routines. */
13381 void add_extern_for_export(const std::string &name,
13382 const ExternCFunction &extern_c_function);
13383
13384 /** Look up a symbol by name in this module or its dependencies. */
13385 Symbol find_symbol_by_name(const std::string &) const;
13386
13387 /** Take an llvm module and compile it. The requested exports will
13388 be available via the exports method. */
13389 void compile_module(std::unique_ptr<llvm::Module> mod,
13390 const std::string &function_name, const Target &target,
13391 const std::vector<JITModule> &dependencies = std::vector<JITModule>(),
13392 const std::vector<std::string> &requested_exports = std::vector<std::string>());
13393
13394 /** See JITSharedRuntime::memoization_cache_set_size */
13395 void memoization_cache_set_size(int64_t size) const;
13396
13397 /** See JITSharedRuntime::memoization_cache_evict */
13398 void memoization_cache_evict(uint64_t eviction_key) const;
13399
13400 /** See JITSharedRuntime::reuse_device_allocations */
13401 void reuse_device_allocations(bool) const;
13402
13403 /** Return true if compile_module has been called on this module. */
13404 bool compiled() const;
13405};
13406
13407typedef int (*halide_task)(void *user_context, int, uint8_t *);
13408
/** A set of pluggable handlers for runtime services used by
 * JIT-compiled code. Any entry left as nullptr means "use the default
 * runtime implementation". For most handlers the leading void *
 * parameter appears to be the user context pointer (see
 * JITUserContext) — confirm against the runtime implementation. */
struct JITHandlers {
    // Handler for print/debug text emitted by the pipeline.
    void (*custom_print)(void *, const char *){nullptr};
    // Replacements for the runtime's memory allocation routines.
    void *(*custom_malloc)(void *, size_t){nullptr};
    void (*custom_free)(void *, void *){nullptr};
    // Handlers for running a single task, and for the parallel-for
    // loop runtime (see the halide_task typedef above).
    int (*custom_do_task)(void *, halide_task, int, uint8_t *){nullptr};
    int (*custom_do_par_for)(void *, halide_task, int, int, uint8_t *){nullptr};
    // Handler invoked with error messages reported by the pipeline.
    void (*custom_error)(void *, const char *){nullptr};
    // Handler for tracing events (receives a halide_trace_event_t).
    int32_t (*custom_trace)(void *, const halide_trace_event_t *){nullptr};
    // Hooks for symbol and dynamic-library resolution. Note these
    // signatures do not take a user context pointer.
    void *(*custom_get_symbol)(const char *name){nullptr};
    void *(*custom_load_library)(const char *name){nullptr};
    void *(*custom_get_library_symbol)(void *lib, const char *name){nullptr};
};
13421
/** Pairs an opaque user context pointer with the handler overrides to
 * use when running a JIT-compiled pipeline. Populated via
 * JITSharedRuntime::init_jit_user_context. */
struct JITUserContext {
    void *user_context;   // Opaque pointer, passed through to the handlers.
    JITHandlers handlers; // Overrides of runtime services; see JITHandlers.
};
13426
/** Manages the JIT runtime modules shared across all JIT compilation
 * in the process, plus process-wide runtime settings. */
class JITSharedRuntime {
public:
    // Note only the first llvm::Module passed in here is used. The same shared runtime is used for all JIT.
    /** Get the shared runtime modules for the given target. If
     * 'create' is true the runtime is presumably built on demand when
     * not already present — confirm in the implementation. */
    static std::vector<JITModule> get(llvm::Module *m, const Target &target, bool create = true);

    /** Fill in 'jit_user_context' with the given user context pointer
     * and handler overrides. */
    static void init_jit_user_context(JITUserContext &jit_user_context, void *user_context, const JITHandlers &handlers);

    /** Install process-wide default handlers. The returned
     * JITHandlers is presumably the previously-installed set —
     * confirm in the implementation. */
    static JITHandlers set_default_handlers(const JITHandlers &handlers);

    /** Set the maximum number of bytes used by memoization caching.
     * If you are compiling statically, you should include HalideRuntime.h
     * and call halide_memoization_cache_set_size() instead.
     */
    static void memoization_cache_set_size(int64_t size);

    /** Evict all cache entries that were tagged with the given
     * eviction_key in the memoize scheduling directive. If you are
     * compiling statically, you should include HalideRuntime.h and
     * call halide_memoization_cache_evict() instead.
     */
    static void memoization_cache_evict(uint64_t eviction_key);

    /** Set whether or not Halide may hold onto and reuse device
     * allocations to avoid calling expensive device API allocation
     * functions. If you are compiling statically, you should include
     * HalideRuntime.h and call halide_reuse_device_allocations
     * instead. */
    static void reuse_device_allocations(bool);

    /** Release everything held by the shared runtime. */
    static void release_all();
};
13456
13457void *get_symbol_address(const char *s);
13458
13459} // namespace Internal
13460} // namespace Halide
13461
13462#endif
13463#ifndef HALIDE_MODULE_H
13464#define HALIDE_MODULE_H
13465
13466/** \file
13467 *
13468 * Defines Module, an IR container that fully describes a Halide program.
13469 */
13470
13471#include <functional>
13472#include <map>
13473#include <memory>
13474#include <string>
13475
13476#ifndef HALIDE_EXTERNAL_CODE_H
13477#define HALIDE_EXTERNAL_CODE_H
13478
13479#include <vector>
13480
13481
13482namespace Halide {
13483
13484class ExternalCode {
13485private:
13486 enum Kind {
13487 LLVMBitcode,
13488 DeviceCode,
13489 CPlusPlusSource,
13490 } kind;
13491
13492 Target llvm_target; // For LLVMBitcode.
13493 DeviceAPI device_code_kind;
13494
13495 std::vector<uint8_t> code;
13496
13497 // Used for debugging and naming the module to llvm.
13498 std::string nametag;
13499
13500 ExternalCode(Kind kind, const Target &llvm_target, DeviceAPI device_api, const std::vector<uint8_t> &code, const std::string &name)
13501 : kind(kind), llvm_target(llvm_target), device_code_kind(device_api), code(code), nametag(name) {
13502 }
13503
13504public:
13505 /** Construct an ExternalCode container from llvm bitcode. The
13506 * result can be passed to Halide::Module::append to have the
13507 * contained bitcode linked with that module. The Module's target
13508 * must match the target argument here on architecture, bit width,
13509 * and operating system. The name is used as a unique identifier
13510 * for the external code and duplicates will be reduced to a
13511 * single instance. Halide does not do anything other than to
13512 * compare names for equality. To guarantee uniqueness in public
13513 * code, we suggest using a Java style inverted domain name
13514 * followed by organization specific naming. E.g.:
13515 * com.initech.y2k.5d2ac80aaf522eec6cb4b40f39fb923f9902bc7e */
13516 static ExternalCode bitcode_wrapper(const Target &cpu_type, const std::vector<uint8_t> &code, const std::string &name) {
13517 return ExternalCode(LLVMBitcode, cpu_type, DeviceAPI::None, code, name);
13518 }
13519
13520 /** Construct an ExternalCode container from GPU "source code."
13521 * This container can be used to insert its code into the GPU code
13522 * generated for a given DeviceAPI. The specific type of code
13523 * depends on the device API used as follows:
13524 * CUDA: llvm bitcode for PTX
13525 * OpenCL: OpenCL source code
13526 * GLSL: GLSL source code
13527 * OpenGLCompute: GLSL source code
13528 * Metal: Metal source code
13529 * Hexagon: llvm bitcode for Hexagon
13530 *
13531 * At present, this API is not fully working. See Issue:
13532 * https://github.com/halide/Halide/issues/1971
13533 *
13534 * The name is used as a unique identifier for the external code
13535 * and duplicates will be reduced to a single instance. Halide
13536 * does not do anything other than to compare names for
13537 * equality. To guarantee uniqueness in public code, we suggest
13538 * using a Java style inverted domain name followed by
13539 * organization specific naming. E.g.:
13540 * com.tyrell.nexus-6.53947db86ba97a9ca5ecd5e60052880945bfeb37 */
13541 static ExternalCode device_code_wrapper(DeviceAPI device_api, const std::vector<uint8_t> &code, const std::string &name) {
13542 return ExternalCode(DeviceCode, Target(), device_api, code, name);
13543 }
13544
13545 /** Construct an ExternalCode container from C++ source code. This
13546 * container can be used to insert its code into C++ output from
13547 * Halide.
13548 *
13549 * At present, this API is not fully working. See Issue:
13550 * https://github.com/halide/Halide/issues/1971
13551 *
13552 * The name is used as a unique identifier for the external code
13553 * and duplicates will be reduced to a single instance. Halide
13554 * does not do anything other than to compare names for
13555 * equality. To guarantee uniqueness in public code, we suggest
13556 * using a Java style inverted domain name followed by
13557 * organization specific naming. E.g.:
13558 * com.cyberdyne.skynet.78ad6c411d313f050f172cd3d440f23fdd797d0d */
13559 static ExternalCode c_plus_plus_code_wrapper(const std::vector<uint8_t> &code, const std::string &name) {
13560 return ExternalCode(CPlusPlusSource, Target(), DeviceAPI::None, code, name);
13561 }
13562
13563 /** Return true if this container holds llvm bitcode linkable with
13564 * code generated for the target argument. The matching is done
13565 * on the architecture, bit width, and operating system
13566 * only. Features are ignored. If the container is for
13567 * Target::ArchUnkonwn, it applies to all architectures -- meaning
13568 * it is generic llvm bitcode. If the OS is OSUnknown, it applies
13569 * to all operationg systems. The bit width must match.
13570 *
13571 * Ignoring feature flags isn't too important since generally
13572 * ExternalCode will be constructed in a Generator which has
13573 * access to the feature flags in effect and can select code
13574 * appropriately. */
13575 bool is_for_cpu_target(const Target &host) const {
13576 return kind == LLVMBitcode &&
13577 (llvm_target.arch == Target::ArchUnknown || llvm_target.arch == host.arch) &&
13578 (llvm_target.os == Target::OSUnknown || llvm_target.os == host.os) &&
13579 (llvm_target.bits == host.bits);
13580 }
13581
13582 /** True if this container holds code linkable with a code generated for a GPU. */
13583 bool is_for_device_api(DeviceAPI current_device) const {
13584 return kind == DeviceCode && device_code_kind == current_device;
13585 }
13586
13587 /** True if this container holds C++ source code for inclusion in
13588 * generating C++ output. */
13589 bool is_c_plus_plus_source() const {
13590 return kind == CPlusPlusSource;
13591 }
13592
13593 /** Retrieve the bytes of external code held by this container. */
13594 const std::vector<uint8_t> &contents() const {
13595 return code;
13596 }
13597
13598 /** Retrieve the name of this container. Used to ensure the same
13599 * piece of external code is only included once in linkage. */
13600 const std::string &name() const {
13601 return nametag;
13602 }
13603};
13604
13605} // namespace Halide
13606
13607#endif
13608#ifndef HALIDE_FUNCTION_H
13609#define HALIDE_FUNCTION_H
13610
13611/** \file
13612 * Defines the internal representation of a halide function and related classes
13613 */
13614#include <map>
13615#include <string>
13616#include <utility>
13617#include <vector>
13618
13619
13620namespace Halide {
13621
13622struct ExternFuncArgument;
13623
13624class Var;
13625
13626/** An enum to specify calling convention for extern stages. */
13627enum class NameMangling {
13628 Default, ///< Match whatever is specified in the Target
13629 C, ///< No name mangling
13630 CPlusPlus, ///< C++ name mangling
13631};
13632
13633namespace Internal {
13634
13635struct Call;
13636class Parameter;
13637
/** A reference-counted handle to Halide's internal representation of
 * a function. Similar to a front-end Func object, but with no
 * syntactic sugar to help with definitions. */
class Function {
    FunctionPtr contents;

public:
    /** This lets you use a Function as a key in a map of the form
     * map<Function, Foo, Function::Compare> */
    struct Compare {
        bool operator()(const Function &a, const Function &b) const {
            internal_assert(a.contents.defined() && b.contents.defined());
            return a.contents < b.contents;
        }
    };

    /** Construct a new function with no definitions and no name. This
     * constructor only exists so that you can make vectors of
     * functions, etc.
     */
    Function() = default;

    /** Construct a new function with the given name */
    explicit Function(const std::string &n);

    /** Construct a Function from an existing FunctionContents pointer. Must be non-null */
    explicit Function(const FunctionPtr &);

    /** Get a handle on the halide function contents that this Function
     * represents. */
    FunctionPtr get_contents() const {
        return contents;
    }

    /** Deep copy this Function into 'copy'. It recursively deep copies all called
     * functions, schedules, update definitions, extern func arguments, specializations,
     * and reduction domains. This method does not deep-copy the Parameter objects.
     * This method also takes a map of <old Function, deep-copied version> as input
     * and would use the deep-copied Function from the map if it exists instead of
     * creating a new deep-copy to avoid creating deep-copies of the same Function
     * multiple times. If 'name' is specified, copy's name will be set to that.
     */
    // @{
    void deep_copy(const FunctionPtr &copy, std::map<FunctionPtr, FunctionPtr> &copied_map) const;
    void deep_copy(std::string name, const FunctionPtr &copy,
                   std::map<FunctionPtr, FunctionPtr> &copied_map) const;
    // @}

    /** Add a pure definition to this function. It may not already
     * have a definition. All the free variables in 'value' must
     * appear in the args list. 'value' must not depend on any
     * reduction domain */
    void define(const std::vector<std::string> &args, std::vector<Expr> values);

    /** Add an update definition to this function. It must already
     * have a pure definition but not an update definition, and the
     * length of args must match the length of args used in the pure
     * definition. 'value' must depend on some reduction domain, and
     * may contain variables from that domain as well as pure
     * variables. Any pure variables must also appear as Variables in
     * the args array, and they must have the same name as the pure
     * definition's argument in the same index. */
    void define_update(const std::vector<Expr> &args, std::vector<Expr> values);

    /** Accept a visitor to visit all of the definitions and arguments
     * of this function. */
    void accept(IRVisitor *visitor) const;

    /** Accept a mutator to mutate all of the definitions and
     * arguments of this function. */
    void mutate(IRMutator *mutator);

    /** Get the name of the function. */
    const std::string &name() const;

    /** If this is a wrapper of another func, created by a chain of in
     * or clone_in calls, returns the name of the original
     * Func. Otherwise returns the name. */
    const std::string &origin_name() const;

    /** Get a mutable handle to the init definition. */
    Definition &definition();

    /** Get the init definition. */
    const Definition &definition() const;

    /** Get the pure arguments. */
    const std::vector<std::string> &args() const;

    /** Get the dimensionality. */
    int dimensions() const;

    /** Get the number of outputs. */
    int outputs() const {
        return (int)output_types().size();
    }

    /** Get the types of the outputs. */
    const std::vector<Type> &output_types() const;

    /** Get the right-hand-side of the pure definition. Returns an
     * empty vector if there is no pure definition. */
    const std::vector<Expr> &values() const;

    /** Does this function have a pure definition? */
    bool has_pure_definition() const;

    /** Does this function *only* have a pure definition? */
    bool is_pure() const {
        return (has_pure_definition() &&
                !has_update_definition() &&
                !has_extern_definition());
    }

    /** Is it legal to inline this function? */
    bool can_be_inlined() const;

    /** Get a handle to the function-specific schedule for the purpose
     * of modifying it. */
    FuncSchedule &schedule();

    /** Get a const handle to the function-specific schedule for inspecting it. */
    const FuncSchedule &schedule() const;

    /** Get a handle on the output buffer used for setting constraints
     * on it. */
    const std::vector<Parameter> &output_buffers() const;

    /** Get a mutable handle to the stage-specific schedule for the update
     * stage. */
    StageSchedule &update_schedule(int idx = 0);

    /** Get a mutable handle to this function's update definition at
     * index 'idx'. */
    Definition &update(int idx = 0);

    /** Get a const reference to this function's update definition at
     * index 'idx'. */
    const Definition &update(int idx = 0) const;

    /** Get a const reference to this function's update definitions. */
    const std::vector<Definition> &updates() const;

    /** Does this function have an update definition? */
    bool has_update_definition() const;

    /** Check if the function has an extern definition. */
    bool has_extern_definition() const;

    /** Get the name mangling specified for the extern definition. */
    NameMangling extern_definition_name_mangling() const;

    /** Make a call node to the extern definition. An error if the
     * function has no extern definition. */
    Expr make_call_to_extern_definition(const std::vector<Expr> &args,
                                        const Target &t) const;

    /** Get the proxy Expr for the extern stage. This is an expression
     * known to have the same data access pattern as the extern
     * stage. It must touch at least all of the memory that the extern
     * stage does, though it is permissible for it to be conservative
     * and touch a superset. For most Functions, including those with
     * extern definitions, this will be an undefined Expr. */
    // @{
    Expr extern_definition_proxy_expr() const;
    Expr &extern_definition_proxy_expr();
    // @}

    /** Add an external definition of this Func. */
    void define_extern(const std::string &function_name,
                       const std::vector<ExternFuncArgument> &args,
                       const std::vector<Type> &types,
                       const std::vector<Var> &dims,
                       NameMangling mangling, DeviceAPI device_api);

    /** Retrieve the arguments of the extern definition. */
    // @{
    const std::vector<ExternFuncArgument> &extern_arguments() const;
    std::vector<ExternFuncArgument> &extern_arguments();
    // @}

    /** Get the name of the extern function called for an extern
     * definition. */
    const std::string &extern_function_name() const;

    /** Get the DeviceAPI declared for an extern function. */
    DeviceAPI extern_function_device_api() const;

    /** Test for equality of identity. */
    bool same_as(const Function &other) const {
        return contents.same_as(other.contents);
    }

    /** Get a const handle to the debug filename. */
    const std::string &debug_file() const;

    /** Get a handle to the debug filename. */
    std::string &debug_file();

    /** Use as an extern argument to another function. */
    operator ExternFuncArgument() const;

    /** Tracing calls and accessors, passed down from the Func
     * equivalents. */
    // @{
    void trace_loads();
    void trace_stores();
    void trace_realizations();
    void add_trace_tag(const std::string &trace_tag);
    bool is_tracing_loads() const;
    bool is_tracing_stores() const;
    bool is_tracing_realizations() const;
    const std::vector<std::string> &get_trace_tags() const;
    // @}

    /** Replace this Function's LoopLevels with locked copies that
     * cannot be mutated further. */
    void lock_loop_levels();

    /** Mark function as frozen, which means it cannot accept new
     * definitions. */
    void freeze();

    /** Check if a function has been frozen. If so, it is an error to
     * add new definitions. */
    bool frozen() const;

    /** Make a new Function with the same lifetime as this one, and
     * return a strong reference to it. Useful to create Functions which
     * have circular references to this one - e.g. the wrappers
     * produced by Func::in. */
    Function new_function_in_same_group(const std::string &);

    /** Mark calls of this function by 'f' to be replaced with its wrapper
     * during the lowering stage. If the string 'f' is empty, it means replace
     * all calls to this function by all other functions (excluding itself) in
     * the pipeline with the wrapper. This will also freeze 'wrapper' to prevent
     * user from updating the values of the Function it wraps via the wrapper.
     * See \ref Func::in for more details. */
    // @{
    void add_wrapper(const std::string &f, Function &wrapper);
    const std::map<std::string, FunctionPtr> &wrappers() const;
    // @}

    /** Check if a Function is a trivial wrapper around another
     * Function, Buffer, or Parameter. Returns the Call node if it
     * is. Otherwise returns null.
     */
    const Call *is_wrapper() const;

    /** Replace every call to Functions in 'substitutions' keys by all Exprs
     * referenced in this Function to call to their substitute Functions (i.e.
     * the corresponding values in 'substitutions' map). */
    // @{
    Function &substitute_calls(const std::map<FunctionPtr, FunctionPtr> &substitutions);
    Function &substitute_calls(const Function &orig, const Function &substitute);
    // @}

    /** Return true iff the name matches one of the Function's pure args. */
    bool is_pure_arg(const std::string &name) const;
};
13899
13900/** Deep copy an entire Function DAG. */
13901std::pair<std::vector<Function>, std::map<std::string, Function>> deep_copy(
13902 const std::vector<Function> &outputs,
13903 const std::map<std::string, Function> &env);
13904
13905} // namespace Internal
13906} // namespace Halide
13907
13908#endif
13909
13910namespace Halide {
13911
13912template<typename T>
13913class Buffer;
13914struct Target;
13915
/** Enums specifying various kinds of outputs that can be produced from a Halide Pipeline. */
// The human-readable name, filename extension, and multi-target
// behavior for each output kind are described by the OutputInfo
// returned from Internal::get_output_info().
enum class Output {
    assembly,
    bitcode,
    c_header,
    c_source,
    compiler_log,
    cpp_stub,
    featurization,
    llvm_assembly,
    object,
    python_extension,
    pytorch_wrapper,
    registration,
    schedule,
    static_library,
    stmt,
    stmt_html,
};
13935
/** Type of linkage a function in a lowered Halide module can have.
    Also controls whether auxiliary functions and metadata are generated:
    of the three, only ExternalPlusMetadata produces the argv wrapper
    and argument metadata. */
enum class LinkageType {
    External,             ///< Visible externally.
    ExternalPlusMetadata, ///< Visible externally. Argument metadata and an argv wrapper are also generated.
    Internal,             ///< Not visible externally, similar to 'static' linkage in C.
};
13943
13944namespace Internal {
13945
/** Describes one Output kind: its human-readable name, the filename
 * extension used for it, and how it behaves under multi-target
 * compilation. */
struct OutputInfo {
    std::string name, extension;

    // `is_multi` indicates how these outputs are generated
    // when using the compile_to_multitarget_xxx() APIs (or via the
    // Generator command-line mode):
    //
    // - If `is_multi` is true, then a separate file of this Output type is
    //   generated for each target in the multitarget (e.g. object files,
    //   assembly files, etc). Each of the files will have a suffix appended
    //   that is based on the specific subtarget.
    //
    // - If `is_multi` is false, then only one file of this Output type
    //   regardless of how many targets are in the multitarget. No additional
    //   suffix will be appended to the filename.
    //
    bool is_multi{false};
};
/** Get the table mapping each Output kind to its OutputInfo. Takes a
 * Target, so the returned info (e.g. extensions) may be
 * target-dependent — NOTE(review): inferred from the parameter;
 * confirm in the implementation. */
std::map<Output, const OutputInfo> get_output_info(const Target &target);
13965
/** Definition of an argument to a LoweredFunc. This is similar to
 * Argument, except it enables passing extra information useful to
 * some targets to LoweredFunc. */
struct LoweredArgument : public Argument {
    /** For scalar arguments, the modulus and remainder of this
     * argument. */
    ModulusRemainder alignment;

    LoweredArgument() = default;
    /** Wrap a plain Argument; 'alignment' is left default-constructed. */
    explicit LoweredArgument(const Argument &arg)
        : Argument(arg) {
    }
    /** Construct from the individual Argument fields; 'alignment' is
     * left default-constructed. */
    LoweredArgument(const std::string &_name, Kind _kind, const Type &_type, uint8_t _dimensions, const ArgumentEstimates &argument_estimates)
        : Argument(_name, _kind, _type, _dimensions, argument_estimates) {
    }
};
13982
/** Definition of a lowered function. This object provides a concrete
 * mapping between parameters used in the function body and their
 * declarations in the argument list. */
struct LoweredFunc {
    std::string name;

    /** Arguments referred to in the body of this function. */
    std::vector<LoweredArgument> args;

    /** Body of this function. */
    Stmt body;

    /** The linkage of this function. */
    LinkageType linkage;

    /** The name-mangling choice for the function. Defaults to using
     * the Target. */
    NameMangling name_mangling;

    LoweredFunc(const std::string &name,
                const std::vector<LoweredArgument> &args,
                Stmt body,
                LinkageType linkage,
                NameMangling mangling = NameMangling::Default);
    /** Convenience overload taking plain Arguments. NOTE(review):
     * presumably each Argument is promoted to a LoweredArgument with
     * default alignment; confirm in the implementation. */
    LoweredFunc(const std::string &name,
                const std::vector<Argument> &args,
                Stmt body,
                LinkageType linkage,
                NameMangling mangling = NameMangling::Default);
};
14013
14014} // namespace Internal
14015
14016namespace Internal {
14017struct ModuleContents;
14018class CompilerLogger;
14019} // namespace Internal
14020
14021struct AutoSchedulerResults;
14022
/** A halide module. This represents IR containing lowered function
 * definitions and buffers. */
class Module {
    Internal::IntrusivePtr<Internal::ModuleContents> contents;

public:
    Module(const std::string &name, const Target &target);

    /** Get the target this module has been lowered for. */
    const Target &target() const;

    /** The name of this module. This is used as the default filename
     * for output operations. */
    const std::string &name() const;

    /** If this Module had an auto-generated schedule, return a read-only pointer
     * to the AutoSchedulerResults. If not, return nullptr. */
    const AutoSchedulerResults *get_auto_scheduler_results() const;

    /** Return whether this module uses strict floating-point anywhere. */
    bool any_strict_float() const;

    /** The declarations contained in this module. */
    // @{
    const std::vector<Buffer<void>> &buffers() const;
    const std::vector<Internal::LoweredFunc> &functions() const;
    std::vector<Internal::LoweredFunc> &functions();
    const std::vector<Module> &submodules() const;
    const std::vector<ExternalCode> &external_code() const;
    // @}

    /** Return the function with the given name. If no such function
     * exists in this module, assert. */
    Internal::LoweredFunc get_function_by_name(const std::string &name) const;

    /** Add a declaration to this module. */
    // @{
    void append(const Buffer<void> &buffer);
    void append(const Internal::LoweredFunc &function);
    void append(const Module &module);
    void append(const ExternalCode &external_code);
    // @}

    /** Compile a halide Module to variety of outputs, depending on
     * the fields set in output_files. */
    void compile(const std::map<Output, std::string> &output_files) const;

    /** Compile a halide Module to in-memory object code. Currently
     * only supports LLVM based compilation, but should be extended to
     * handle source code backends. */
    Buffer<uint8_t> compile_to_buffer() const;

    /** Return a new module with all submodules compiled to buffers
     * on the result Module. */
    Module resolve_submodules() const;

    /** When generating metadata from this module, remap any occurrences
     * of 'from' into 'to'. */
    void remap_metadata_name(const std::string &from, const std::string &to) const;

    /** Retrieve the metadata name map. */
    std::map<std::string, std::string> get_metadata_name_map() const;

    /** Set the AutoSchedulerResults for the Module. It is an error to call this
     * multiple times for a given Module. */
    void set_auto_scheduler_results(const AutoSchedulerResults &results);

    /** Set whether this module uses strict floating-point directives anywhere. */
    void set_any_strict_float(bool any_strict_float);
};
14093
14094/** Link a set of modules together into one module. */
14095Module link_modules(const std::string &name, const std::vector<Module> &modules);
14096
14097/** Create an object file containing the Halide runtime for a given target. For
14098 * use with Target::NoRuntime. Standalone runtimes are only compatible with
14099 * pipelines compiled by the same build of Halide used to call this function. */
14100void compile_standalone_runtime(const std::string &object_filename, const Target &t);
14101
14102/** Create an object and/or static library file containing the Halide runtime
14103 * for a given target. For use with Target::NoRuntime. Standalone runtimes are
14104 * only compatible with pipelines compiled by the same build of Halide used to
14105 * call this function. Return a map with just the actual outputs filled in
14106 * (typically, Output::object and/or Output::static_library).
14107 */
14108std::map<Output, std::string> compile_standalone_runtime(const std::map<Output, std::string> &output_files, const Target &t);
14109
14110using ModuleFactory = std::function<Module(const std::string &fn_name, const Target &target)>;
14111using CompilerLoggerFactory = std::function<std::unique_ptr<Internal::CompilerLogger>(const std::string &fn_name, const Target &target)>;
14112
14113void compile_multitarget(const std::string &fn_name,
14114 const std::map<Output, std::string> &output_files,
14115 const std::vector<Target> &targets,
14116 const std::vector<std::string> &suffixes,
14117 const ModuleFactory &module_factory,
14118 const CompilerLoggerFactory &compiler_logger_factory = nullptr);
14119
14120} // namespace Halide
14121
14122#endif
14123#ifndef HALIDE_PARAM_H
14124#define HALIDE_PARAM_H
14125
14126#include <type_traits>
14127
14128#ifndef HALIDE_EXTERNFUNCARGUMENT_H
14129#define HALIDE_EXTERNFUNCARGUMENT_H
14130
14131/** \file
14132 * Defines the internal representation of a halide ExternFuncArgument
14133 */
14134
14135
14136namespace Halide {
14137
14138/** An argument to an extern-defined Func. May be a Function, Buffer,
14139 * ImageParam or Expr. */
14140struct ExternFuncArgument {
14141 enum ArgType { UndefinedArg = 0,
14142 FuncArg,
14143 BufferArg,
14144 ExprArg,
14145 ImageParamArg };
14146 ArgType arg_type = UndefinedArg;
14147 Internal::FunctionPtr func;
14148 Buffer<> buffer;
14149 Expr expr;
14150 Internal::Parameter image_param;
14151
14152 ExternFuncArgument(Internal::FunctionPtr f)
14153 : arg_type(FuncArg), func(std::move(f)) {
14154 }
14155
14156 template<typename T>
14157 ExternFuncArgument(Buffer<T> b)
14158 : arg_type(BufferArg), buffer(b) {
14159 }
14160 ExternFuncArgument(Expr e)
14161 : arg_type(ExprArg), expr(std::move(e)) {
14162 }
14163 ExternFuncArgument(int e)
14164 : arg_type(ExprArg), expr(e) {
14165 }
14166 ExternFuncArgument(float e)
14167 : arg_type(ExprArg), expr(e) {
14168 }
14169
14170 ExternFuncArgument(const Internal::Parameter &p)
14171 : arg_type(ImageParamArg), image_param(p) {
14172 // Scalar params come in via the Expr constructor.
14173 internal_assert(p.is_buffer());
14174 }
14175 ExternFuncArgument() = default;
14176
14177 bool is_func() const {
14178 return arg_type == FuncArg;
14179 }
14180 bool is_expr() const {
14181 return arg_type == ExprArg;
14182 }
14183 bool is_buffer() const {
14184 return arg_type == BufferArg;
14185 }
14186 bool is_image_param() const {
14187 return arg_type == ImageParamArg;
14188 }
14189 bool defined() const {
14190 return arg_type != UndefinedArg;
14191 }
14192};
14193
14194} // namespace Halide
14195
14196#endif // HALIDE_EXTERNFUNCARGUMENT_H
14197
14198/** \file
14199 *
14200 * Classes for declaring scalar parameters to halide pipelines
14201 */
14202
14203namespace Halide {
14204
14205/** A scalar parameter to a halide pipeline. If you're jitting, this
14206 * should be bound to an actual value of type T using the set method
14207 * before you realize the function uses this. If you're statically
14208 * compiling, this param should appear in the argument list. */
/** A scalar parameter to a halide pipeline. If you're jitting, this
 * should be bound to an actual value of type T using the set method
 * before you realize the function uses this. If you're statically
 * compiling, this param should appear in the argument list. */
template<typename T = void>
class Param {
    /** A reference-counted handle on the internal parameter object */
    Internal::Parameter param;

    // This is a deliberately non-existent type that allows us to compile Param<>
    // but provide less-confusing error messages if you attempt to call get<> or set<>
    // without explicit types.
    struct DynamicParamType;

    /** T unless T is (const) void, in which case pointer-to-useless-type. */
    using not_void_T = typename std::conditional<std::is_void<T>::value, DynamicParamType *, T>::type;

    // Reject the reserved name "__user_context", which was formerly used to
    // opt a pipeline in to explicit user-context arguments.
    void check_name() const {
        user_assert(param.name() != "__user_context")
            << "Param<void*>(\"__user_context\") "
            << "is no longer used to control whether Halide functions take explicit "
            << "user_context arguments. Use set_custom_user_context() when jitting, "
            << "or add Target::UserContext to the Target feature set when compiling ahead of time.";
    }

    // Pack a (type-code, bit-width) pair into one int so a runtime type can be
    // dispatched on in a switch statement.
    // Must be constexpr to allow use in case clauses.
    inline static constexpr int halide_type_code(halide_type_code_t code, int bits) {
        return (((int)code) << 8) | bits;
    }

    // Allow all Param<> variants friend access to each other
    template<typename OTHER_TYPE>
    friend class Param;

public:
    /** True if the Halide type is not void (or const void). */
    static constexpr bool has_static_type = !std::is_void<T>::value;

    /** Get the Halide type of T. Callers should not use the result if
     * has_static_type is false. */
    static Type static_type() {
        internal_assert(has_static_type);
        return type_of<T>();
    }

    /** Construct a scalar parameter of type T with a unique
     * auto-generated name */
    // @{
    Param()
        : param(type_of<T>(), false, 0, Internal::make_entity_name(this, "Halide:.*:Param<.*>", 'p')) {
        static_assert(has_static_type, "Cannot use this ctor without an explicit type.");
    }
    explicit Param(Type t)
        : param(t, false, 0, Internal::make_entity_name(this, "Halide:.*:Param<.*>", 'p')) {
        static_assert(!has_static_type, "Cannot use this ctor with an explicit type.");
    }
    // @}

    /** Construct a scalar parameter of type T with the given name. */
    // @{
    explicit Param(const std::string &n)
        : param(type_of<T>(), false, 0, n) {
        static_assert(has_static_type, "Cannot use this ctor without an explicit type.");
        check_name();
    }
    explicit Param(const char *n)
        : param(type_of<T>(), false, 0, n) {
        static_assert(has_static_type, "Cannot use this ctor without an explicit type.");
        check_name();
    }
    Param(Type t, const std::string &n)
        : param(t, false, 0, n) {
        static_assert(!has_static_type, "Cannot use this ctor with an explicit type.");
        check_name();
    }
    // @}

    /** Construct a scalar parameter of type T an initial value of
     * 'val'. Only triggers for non-pointer types. */
    template<typename T2 = T, typename std::enable_if<!std::is_pointer<T2>::value>::type * = nullptr>
    explicit Param(not_void_T val)
        : param(type_of<T>(), false, 0, Internal::make_entity_name(this, "Halide:.*:Param<.*>", 'p')) {
        static_assert(has_static_type, "Cannot use this ctor without an explicit type.");
        set<not_void_T>(val);
    }

    /** Construct a scalar parameter of type T with the given name
     * and an initial value of 'val'. */
    Param(const std::string &n, not_void_T val)
        : param(type_of<T>(), false, 0, n) {
        check_name();
        static_assert(has_static_type, "Cannot use this ctor without an explicit type.");
        set<not_void_T>(val);
    }

    /** Construct a scalar parameter of type T with an initial value of 'val'
     * and a given min and max. */
    Param(not_void_T val, const Expr &min, const Expr &max)
        : param(type_of<T>(), false, 0, Internal::make_entity_name(this, "Halide:.*:Param<.*>", 'p')) {
        static_assert(has_static_type, "Cannot use this ctor without an explicit type.");
        set_range(min, max);
        set<not_void_T>(val);
    }

    /** Construct a scalar parameter of type T with the given name
     * and an initial value of 'val' and a given min and max. */
    Param(const std::string &n, not_void_T val, const Expr &min, const Expr &max)
        : param(type_of<T>(), false, 0, n) {
        static_assert(has_static_type, "Cannot use this ctor without an explicit type.");
        check_name();
        set_range(min, max);
        set<not_void_T>(val);
    }

    /** Construct a Param<void> from any other Param. */
    template<typename OTHER_TYPE, typename T2 = T, typename std::enable_if<std::is_void<T2>::value>::type * = nullptr>
    Param(const Param<OTHER_TYPE> &other)
        : param(other.param) {
        // empty
    }

    /** Construct a Param<non-void> from a Param with matching type.
     * (Do the check at runtime so that we can assign from Param<void> if the types are compatible.) */
    template<typename OTHER_TYPE, typename T2 = T, typename std::enable_if<!std::is_void<T2>::value>::type * = nullptr>
    Param(const Param<OTHER_TYPE> &other)
        : param(other.param) {
        user_assert(other.type() == type_of<T>())
            << "Param<" << type_of<T>() << "> cannot be constructed from a Param with type " << other.type();
    }

    /** Copy a Param<void> from any other Param. */
    template<typename OTHER_TYPE, typename T2 = T, typename std::enable_if<std::is_void<T2>::value>::type * = nullptr>
    Param<T> &operator=(const Param<OTHER_TYPE> &other) {
        param = other.param;
        return *this;
    }

    /** Copy a Param<non-void> from a Param with matching type.
     * (Do the check at runtime so that we can assign from Param<void> if the types are compatible.) */
    template<typename OTHER_TYPE, typename T2 = T, typename std::enable_if<!std::is_void<T2>::value>::type * = nullptr>
    Param<T> &operator=(const Param<OTHER_TYPE> &other) {
        user_assert(other.type() == type_of<T>())
            << "Param<" << type_of<T>() << "> cannot be copied from a Param with type " << other.type();
        param = other.param;
        return *this;
    }

    /** Get the name of this parameter */
    const std::string &name() const {
        return param.name();
    }

    /** Get the current value of this parameter. Only meaningful when jitting.
        Asserts if type does not exactly match the Parameter's type. */
    template<typename T2 = not_void_T>
    HALIDE_NO_USER_CODE_INLINE T2 get() const {
        return param.scalar<T2>();
    }

    /** Set the current value of this parameter. Only meaningful when jitting.
        Asserts if type is not losslessly-convertible to Parameter's type. */
    // @{
    template<typename SOME_TYPE, typename T2 = T, typename std::enable_if<!std::is_void<T2>::value>::type * = nullptr>
    HALIDE_NO_USER_CODE_INLINE void set(const SOME_TYPE &val) {
        user_assert(Internal::IsRoundtrippable<T>::value(val))
            << "The value " << val << " cannot be losslessly converted to type " << type();
        param.set_scalar<T>(val);
    }

    // Specialized version for when T = void (thus the type is only known at runtime,
    // not compiletime). Note that this actually works fine for all Params; we specialize
    // it just to reduce code size for the common case of T != void.
    template<typename SOME_TYPE, typename T2 = T, typename std::enable_if<std::is_void<T2>::value>::type * = nullptr>
    HALIDE_NO_USER_CODE_INLINE void set(const SOME_TYPE &val) {
        // Each case checks roundtrippability for the runtime type, then stores
        // the value via the statically-typed set_scalar for that type.
#define HALIDE_HANDLE_TYPE_DISPATCH(CODE, BITS, TYPE)                                \
    case halide_type_code(CODE, BITS):                                               \
        user_assert(Internal::IsRoundtrippable<TYPE>::value(val))                    \
            << "The value " << val << " cannot be losslessly converted to type " << type; \
        param.set_scalar<TYPE>(Internal::StaticCast<TYPE>::value(val));              \
        break;

        const Type type = param.type();
        switch (halide_type_code(type.code(), type.bits())) {
            HALIDE_HANDLE_TYPE_DISPATCH(halide_type_float, 32, float)
            HALIDE_HANDLE_TYPE_DISPATCH(halide_type_float, 64, double)
            HALIDE_HANDLE_TYPE_DISPATCH(halide_type_int, 8, int8_t)
            HALIDE_HANDLE_TYPE_DISPATCH(halide_type_int, 16, int16_t)
            HALIDE_HANDLE_TYPE_DISPATCH(halide_type_int, 32, int32_t)
            HALIDE_HANDLE_TYPE_DISPATCH(halide_type_int, 64, int64_t)
            HALIDE_HANDLE_TYPE_DISPATCH(halide_type_uint, 1, bool)
            HALIDE_HANDLE_TYPE_DISPATCH(halide_type_uint, 8, uint8_t)
            HALIDE_HANDLE_TYPE_DISPATCH(halide_type_uint, 16, uint16_t)
            HALIDE_HANDLE_TYPE_DISPATCH(halide_type_uint, 32, uint32_t)
            HALIDE_HANDLE_TYPE_DISPATCH(halide_type_uint, 64, uint64_t)
            HALIDE_HANDLE_TYPE_DISPATCH(halide_type_handle, 64, uint64_t)  // Handle types are always set via set_scalar<uint64_t>, not set_scalar<void*>
        default:
            internal_error << "Unsupported type in Param::set<" << type << ">\n";
        }

#undef HALIDE_HANDLE_TYPE_DISPATCH
    }
    // @}

    /** Get the halide type of the Param */
    Type type() const {
        return param.type();
    }

    /** Get or set the possible range of this parameter. Use undefined
     * Exprs to mean unbounded. */
    // @{
    void set_range(const Expr &min, const Expr &max) {
        set_min_value(min);
        set_max_value(max);
    }

    void set_min_value(Expr min) {
        // Coerce the bound to the parameter's own type if necessary.
        if (min.defined() && min.type() != param.type()) {
            min = Internal::Cast::make(param.type(), min);
        }
        param.set_min_value(min);
    }

    void set_max_value(Expr max) {
        // Coerce the bound to the parameter's own type if necessary.
        if (max.defined() && max.type() != param.type()) {
            max = Internal::Cast::make(param.type(), max);
        }
        param.set_max_value(max);
    }

    Expr min_value() const {
        return param.min_value();
    }

    Expr max_value() const {
        return param.max_value();
    }
    // @}

    /** Set the estimated value for this parameter, for use by autoschedulers.
     * Asserts if the value is not losslessly convertible to the Param's type. */
    template<typename SOME_TYPE>
    void set_estimate(const SOME_TYPE &value) {
        // NOTE(review): this uses IsRoundtrippable<T> even when T is void
        // (no runtime-type dispatch as in set()) — presumably the T == void
        // case of IsRoundtrippable accepts everything; confirm at its definition.
        user_assert(Internal::IsRoundtrippable<T>::value(value))
            << "The value " << value << " cannot be losslessly converted to type " << type();
        param.set_estimate(Expr(value));
    }

    /** You can use this parameter as an expression in a halide
     * function definition */
    operator Expr() const {
        return Internal::Variable::make(param.type(), name(), param);
    }

    /** Using a param as the argument to an external stage treats it
     * as an Expr */
    operator ExternFuncArgument() const {
        return Expr(*this);
    }

    /** Construct the appropriate argument matching this parameter,
     * for the purpose of generating the right type signature when
     * statically compiling halide pipelines. */
    operator Argument() const {
        return Argument(name(), Argument::InputScalar, type(), 0,
                        param.get_argument_estimates());
    }

    /** Access the underlying reference-counted Parameter handle. */
    // @{
    const Internal::Parameter &parameter() const {
        return param;
    }

    Internal::Parameter &parameter() {
        return param;
    }
    // @}
};
14479
14480/** Returns an Expr corresponding to the user context passed to
14481 * the function (if any). It is rare that this function is necessary
14482 * (e.g. to pass the user context to an extern function written in C). */
14483inline Expr user_context_value() {
14484 return Internal::Variable::make(Handle(), "__user_context",
14485 Internal::Parameter(Handle(), false, 0, "__user_context"));
14486}
14487
14488} // namespace Halide
14489
14490#endif
14491#ifndef HALIDE_PIPELINE_H
14492#define HALIDE_PIPELINE_H
14493
14494/** \file
14495 *
14496 * Defines the front-end class representing an entire Halide imaging
14497 * pipeline.
14498 */
14499
14500#include <map>
14501#include <vector>
14502
14503#ifndef HALIDE_PARAM_MAP_H
14504#define HALIDE_PARAM_MAP_H
14505
14506/** \file
14507 * Defines a collection of parameters to be passed as formal arguments
14508 * to a JIT invocation.
14509 */
14510#include <map>
14511
14512
14513namespace Halide {
14514
14515class ImageParam;
14516
14517class ParamMap {
14518public:
14519 struct ParamMapping {
14520 const Internal::Parameter *parameter{nullptr};
14521 const ImageParam *image_param{nullptr};
14522 halide_scalar_value_t value;
14523 Buffer<> buf;
14524 Buffer<> *buf_out_param;
14525
14526 template<typename T>
14527 ParamMapping(const Param<T> &p, const T &val)
14528 : parameter(&p.parameter()) {
14529 *((T *)&value) = val;
14530 }
14531
14532 ParamMapping(const ImageParam &p, Buffer<> &buf)
14533 : image_param(&p), buf(buf), buf_out_param(nullptr) {
14534 }
14535
14536 template<typename T>
14537 ParamMapping(const ImageParam &p, Buffer<T> &buf)
14538 : image_param(&p), buf(buf), buf_out_param(nullptr) {
14539 }
14540
14541 ParamMapping(const ImageParam &p, Buffer<> *buf_ptr)
14542 : image_param(&p), buf_out_param(buf_ptr) {
14543 }
14544
14545 template<typename T>
14546 ParamMapping(const ImageParam &p, Buffer<T> *buf_ptr)
14547 : image_param(&p), buf_out_param((Buffer<> *)buf_ptr) {
14548 }
14549 };
14550
14551private:
14552 struct ParamArg {
14553 Internal::Parameter mapped_param;
14554 Buffer<> *buf_out_param = nullptr;
14555
14556 ParamArg() = default;
14557 ParamArg(const ParamMapping &pm)
14558 : mapped_param(pm.parameter->type(), false, 0, pm.parameter->name()) {
14559 mapped_param.set_scalar(pm.parameter->type(), pm.value);
14560 }
14561 ParamArg(Buffer<> *buf_ptr)
14562 : buf_out_param(buf_ptr) {
14563 }
14564 ParamArg(const ParamArg &) = default;
14565 };
14566 mutable std::map<const Internal::Parameter, ParamArg> mapping;
14567
14568 void set(const ImageParam &p, const Buffer<> &buf, Buffer<> *buf_out_param);
14569
14570public:
14571 ParamMap() = default;
14572
14573 ParamMap(const std::initializer_list<ParamMapping> &init);
14574
14575 template<typename T>
14576 void set(const Param<T> &p, T val) {
14577 Internal::Parameter v(p.type(), false, 0, p.name());
14578 v.set_scalar<T>(val);
14579 ParamArg pa;
14580 pa.mapped_param = v;
14581 pa.buf_out_param = nullptr;
14582 mapping[p.parameter()] = pa;
14583 }
14584
14585 void set(const ImageParam &p, const Buffer<> &buf) {
14586 set(p, buf, nullptr);
14587 }
14588
14589 size_t size() const {
14590 return mapping.size();
14591 }
14592
14593 /** If there is an entry in the ParamMap for this Parameter, return it.
14594 * Otherwise return the parameter itself. */
14595 // @{
14596 const Internal::Parameter &map(const Internal::Parameter &p, Buffer<> *&buf_out_param) const;
14597
14598 Internal::Parameter &map(Internal::Parameter &p, Buffer<> *&buf_out_param) const;
14599 // @}
14600
14601 /** A const ref to an empty ParamMap. Useful for default function
14602 * arguments, which would otherwise require a copy constructor
14603 * (with llvm in c++98 mode) */
14604 static const ParamMap &empty_map() {
14605 static ParamMap empty_param_map;
14606 return empty_param_map;
14607 }
14608};
14609
14610} // namespace Halide
14611
14612#endif
14613#ifndef HALIDE_REALIZATION_H
14614#define HALIDE_REALIZATION_H
14615
14616#include <cstdint>
14617#include <vector>
14618
14619
14620/** \file
14621 *
14622 * Defines Realization - a vector of Buffer for use in pipelines with multiple outputs.
14623 */
14624
14625namespace Halide {
14626
14627template<typename T>
14628class Buffer;
14629
14630/** A Realization is a vector of references to existing Buffer objects.
14631 * A pipeline with multiple outputs realize to a Realization. */
/** A Realization is a vector of references to existing Buffer objects.
 *  A pipeline with multiple outputs realize to a Realization. */
class Realization {
private:
    // The referenced buffers; also accessed by the out-of-line member
    // definitions (size, operator[], device_sync, ...).
    std::vector<Buffer<void>> images;

public:
    /** The number of images in the Realization. */
    size_t size() const;

    /** Get a const reference to one of the images. */
    const Buffer<void> &operator[](size_t x) const;

    /** Get a reference to one of the images. */
    Buffer<void> &operator[](size_t x);

    /** Single-element realizations are implicitly castable to Buffers. */
    // NOTE(review): returns images[0] unconditionally; presumably asserts
    // elsewhere (or is UB) if the Realization is empty or multi-output.
    template<typename T>
    operator Buffer<T>() const {
        return images[0];
    }

    /** Construct a Realization that acts as a reference to some
     * existing Buffers. The element type of the Buffers may not be
     * const. */
    template<typename T,
             typename... Args,
             typename = typename std::enable_if<Internal::all_are_convertible<Buffer<void>, Args...>::value>::type>
    Realization(Buffer<T> &a, Args &&...args) {
        images = std::vector<Buffer<void>>({a, args...});
    }

    /** Construct a Realization that refers to the buffers in an
     * existing vector of Buffer<> */
    explicit Realization(std::vector<Buffer<void>> &e);

    /** Call device_sync() for all Buffers in the Realization.
     * If one of the calls returns an error, subsequent Buffers won't have
     * device_sync called; thus callers should consider a nonzero return
     * code to mean that potentially all of the Buffers are in an indeterminate
     * state of sync.
     * Calling this explicitly should rarely be necessary, except for profiling. */
    int device_sync(void *ctx = nullptr);
};
14674
14675} // namespace Halide
14676
14677#endif
14678
14679namespace Halide {
14680
14681struct Argument;
14682class Func;
14683struct PipelineContents;
14684
14685/** A struct representing the machine parameters to generate the auto-scheduled
14686 * code for. */
/** A struct representing the machine parameters to generate the auto-scheduled
 * code for. */
struct MachineParams {
    /** Maximum level of parallelism available. */
    int parallelism;
    /** Size of the last-level cache (in bytes). */
    uint64_t last_level_cache_size;
    /** Indicates how much more expensive is the cost of a load compared to
     * the cost of an arithmetic operation at last level cache. */
    float balance;

    explicit MachineParams(int parallelism, uint64_t llc, float balance)
        : parallelism(parallelism), last_level_cache_size(llc), balance(balance) {
    }

    /** Default machine parameters for generic CPU architecture. */
    static MachineParams generic();

    /** Convert the MachineParams into canonical string form. */
    std::string to_string() const;

    /** Reconstruct a MachineParams from canonical string form. */
    explicit MachineParams(const std::string &s);
};
14709
14710namespace Internal {
14711class IRMutator;
14712} // namespace Internal
14713
14714/**
14715 * Used to determine if the output printed to file should be as a normal string
14716 * or as an HTML file which can be opened in a browerser and manipulated via JS and CSS.*/
14717enum StmtOutputFormat {
14718 Text,
14719 HTML
14720};
14721
namespace {
// Helper for deleting custom lowering passes. In the header so that
// it goes in user code on windows, where you can have multiple heaps.
// (The anonymous namespace here is deliberate for that reason.)
template<typename T>
void delete_lowering_pass(T *pass) {
    delete pass;
}
}  // namespace
14730
/** A custom lowering pass. See Pipeline::add_custom_lowering_pass. */
struct CustomLoweringPass {
    // The IR mutator to run during lowering.
    Internal::IRMutator *pass;
    // Invoked to destroy `pass` when no longer needed — presumably
    // delete_lowering_pass (above) or a user-supplied equivalent; confirm
    // at Pipeline::add_custom_lowering_pass.
    std::function<void()> deleter;
};
14736
14737struct JITExtern;
14738
/** The results an autoscheduler reports back to the caller of auto_schedule(). */
struct AutoSchedulerResults {
    std::string scheduler_name;         // name of the autoscheduler used
    Target target;                      // Target specified to the autoscheduler
    std::string machine_params_string;  // MachineParams specified to the autoscheduler (in string form)
    std::string schedule_source;        // The C++ source code of the generated schedule
    std::vector<uint8_t> featurization; // The featurization of the pipeline (if any)
};
14746
14747class Pipeline;
14748
14749using AutoSchedulerFn = std::function<void(const Pipeline &, const Target &, const MachineParams &, AutoSchedulerResults *outputs)>;
14750
14751/** A class representing a Halide pipeline. Constructed from the Func
14752 * or Funcs that it outputs. */
14753class Pipeline {
14754public:
    /** The destination(s) of a realize() call: a Realization, a raw
     * halide_buffer_t, or one-or-more Buffers. */
    struct RealizationArg {
        // Only one of the following may be non-null
        Realization *r{nullptr};
        halide_buffer_t *buf{nullptr};
        std::unique_ptr<std::vector<Buffer<>>> buffer_list;

        RealizationArg(Realization &r)
            : r(&r) {
        }
        // NOTE(review): stores a pointer into an rvalue Realization; the
        // temporary must outlive this RealizationArg (it does for the usual
        // pattern of constructing it inside a realize() call expression).
        RealizationArg(Realization &&r)
            : r(&r) {
        }
        RealizationArg(halide_buffer_t *buf)
            : buf(buf) {
        }
        template<typename T, int D>
        RealizationArg(Runtime::Buffer<T, D> &dst)
            : buf(dst.raw_buffer()) {
        }
        template<typename T>
        HALIDE_NO_USER_CODE_INLINE RealizationArg(Buffer<T> &dst)
            : buf(dst.raw_buffer()) {
        }
        // Multiple output Buffers: collected into buffer_list.
        template<typename T, typename... Args,
                 typename = typename std::enable_if<Internal::all_are_convertible<Buffer<>, Args...>::value>::type>
        RealizationArg(Buffer<T> &a, Args &&...args) {
            buffer_list.reset(new std::vector<Buffer<>>({a, args...}));
        }
        RealizationArg(RealizationArg &&from) = default;

        /** Number of output destinations represented (1 for buf/raw cases). */
        size_t size() const {
            if (r != nullptr) {
                return r->size();
            } else if (buffer_list) {
                return buffer_list->size();
            }
            return 1;
        }
    };
14794
14795private:
14796 Internal::IntrusivePtr<PipelineContents> contents;
14797
14798 struct JITCallArgs; // Opaque structure to optimize away dynamic allocation in this path.
14799
14800 // For the three method below, precisely one of the first two args should be non-null
14801 void prepare_jit_call_arguments(RealizationArg &output, const Target &target, const ParamMap &param_map,
14802 void *user_context, bool is_bounds_inference, JITCallArgs &args_result);
14803
14804 static std::vector<Internal::JITModule> make_externs_jit_module(const Target &target,
14805 std::map<std::string, JITExtern> &externs_in_out);
14806
14807 static std::map<std::string, AutoSchedulerFn> &get_autoscheduler_map();
14808
14809 static std::string &get_default_autoscheduler_name();
14810
14811 static AutoSchedulerFn find_autoscheduler(const std::string &autoscheduler_name);
14812
14813 int call_jit_code(const Target &target, const JITCallArgs &args);
14814
14815 // Get the value of contents->jit_target, but reality-check that the contents
14816 // sensibly match the value. Return Target() if not jitted.
14817 Target get_compiled_jit_target() const;
14818
14819public:
14820 /** Make an undefined Pipeline object. */
14821 Pipeline();
14822
14823 /** Make a pipeline that computes the given Func. Schedules the
14824 * Func compute_root(). */
14825 Pipeline(const Func &output);
14826
14827 /** Make a pipeline that computes the givens Funcs as
14828 * outputs. Schedules the Funcs compute_root(). */
14829 Pipeline(const std::vector<Func> &outputs);
14830
14831 std::vector<Argument> infer_arguments(const Internal::Stmt &body);
14832
14833 /** Get the Funcs this pipeline outputs. */
14834 std::vector<Func> outputs() const;
14835
14836 /** Generate a schedule for the pipeline using the currently-default autoscheduler. */
14837 AutoSchedulerResults auto_schedule(const Target &target,
14838 const MachineParams &arch_params = MachineParams::generic());
14839
14840 /** Generate a schedule for the pipeline using the specified autoscheduler. */
14841 AutoSchedulerResults auto_schedule(const std::string &autoscheduler_name,
14842 const Target &target,
14843 const MachineParams &arch_params = MachineParams::generic());
14844
14845 /** Add a new the autoscheduler method with the given name. Does not affect the current default autoscheduler.
14846 * It is an error to call this with the same name multiple times. */
14847 static void add_autoscheduler(const std::string &autoscheduler_name, const AutoSchedulerFn &autoscheduler);
14848
14849 /** Globally set the default autoscheduler method to use whenever
14850 * autoscheduling any Pipeline when no name is specified. If the autoscheduler_name isn't in the
14851 * current table of known autoschedulers, assert-fail.
14852 *
14853 * At this time, well-known autoschedulers include:
14854 * "Mullapudi2016" -- heuristics-based; the first working autoscheduler; currently built in to libHalide
14855 * see http://graphics.cs.cmu.edu/projects/halidesched/
14856 * "Adams2019" -- aka "the ML autoscheduler"; currently located in apps/autoscheduler
14857 * see https://halide-lang.org/papers/autoscheduler2019.html
14858 * "Li2018" -- aka "the gradient autoscheduler"; currently located in apps/gradient_autoscheduler.
14859 * see https://people.csail.mit.edu/tzumao/gradient_halide
14860 */
14861 static void set_default_autoscheduler_name(const std::string &autoscheduler_name);
14862
14863 /** Return handle to the index-th Func within the pipeline based on the
14864 * topological order. */
14865 Func get_func(size_t index);
14866
14867 /** Compile and generate multiple target files with single call.
14868 * Deduces target files based on filenames specified in
14869 * output_files map.
14870 */
14871 void compile_to(const std::map<Output, std::string> &output_files,
14872 const std::vector<Argument> &args,
14873 const std::string &fn_name,
14874 const Target &target);
14875
14876 /** Statically compile a pipeline to llvm bitcode, with the given
14877 * filename (which should probably end in .bc), type signature,
14878 * and C function name. If you're compiling a pipeline with a
14879 * single output Func, see also Func::compile_to_bitcode. */
14880 void compile_to_bitcode(const std::string &filename,
14881 const std::vector<Argument> &args,
14882 const std::string &fn_name,
14883 const Target &target = get_target_from_environment());
14884
14885 /** Statically compile a pipeline to llvm assembly, with the given
14886 * filename (which should probably end in .ll), type signature,
14887 * and C function name. If you're compiling a pipeline with a
14888 * single output Func, see also Func::compile_to_llvm_assembly. */
14889 void compile_to_llvm_assembly(const std::string &filename,
14890 const std::vector<Argument> &args,
14891 const std::string &fn_name,
14892 const Target &target = get_target_from_environment());
14893
14894 /** Statically compile a pipeline with multiple output functions to an
14895 * object file, with the given filename (which should probably end in
14896 * .o or .obj), type signature, and C function name (which defaults to
14897 * the same name as this halide function. You probably don't want to
14898 * use this directly; call compile_to_static_library or compile_to_file instead. */
14899 void compile_to_object(const std::string &filename,
14900 const std::vector<Argument> &,
14901 const std::string &fn_name,
14902 const Target &target = get_target_from_environment());
14903
14904 /** Emit a header file with the given filename for a pipeline. The
14905 * header will define a function with the type signature given by
14906 * the second argument, and a name given by the third. You don't
14907 * actually have to have defined any of these functions yet to
14908 * call this. You probably don't want to use this directly; call
14909 * compile_to_static_library or compile_to_file instead. */
14910 void compile_to_header(const std::string &filename,
14911 const std::vector<Argument> &,
14912 const std::string &fn_name,
14913 const Target &target = get_target_from_environment());
14914
14915 /** Statically compile a pipeline to text assembly equivalent to
14916 * the object file generated by compile_to_object. This is useful
14917 * for checking what Halide is producing without having to
14918 * disassemble anything, or if you need to feed the assembly into
14919 * some custom toolchain to produce an object file. */
14920 void compile_to_assembly(const std::string &filename,
14921 const std::vector<Argument> &args,
14922 const std::string &fn_name,
14923 const Target &target = get_target_from_environment());
14924
14925 /** Statically compile a pipeline to C source code. This is useful
14926 * for providing fallback code paths that will compile on many
14927 * platforms. Vectorization will fail, and parallelization will
14928 * produce serial code. */
14929 void compile_to_c(const std::string &filename,
14930 const std::vector<Argument> &,
14931 const std::string &fn_name,
14932 const Target &target = get_target_from_environment());
14933
14934 /** Write out an internal representation of lowered code. Useful
14935 * for analyzing and debugging scheduling. Can emit html or plain
14936 * text. */
14937 void compile_to_lowered_stmt(const std::string &filename,
14938 const std::vector<Argument> &args,
14939 StmtOutputFormat fmt = Text,
14940 const Target &target = get_target_from_environment());
14941
14942 /** Write out the loop nests specified by the schedule for this
14943 * Pipeline's Funcs. Helpful for understanding what a schedule is
14944 * doing. */
14945 void print_loop_nest();
14946
14947 /** Compile to object file and header pair, with the given
14948 * arguments. */
14949 void compile_to_file(const std::string &filename_prefix,
14950 const std::vector<Argument> &args,
14951 const std::string &fn_name,
14952 const Target &target = get_target_from_environment());
14953
14954 /** Compile to static-library file and header pair, with the given
14955 * arguments. */
14956 void compile_to_static_library(const std::string &filename_prefix,
14957 const std::vector<Argument> &args,
14958 const std::string &fn_name,
14959 const Target &target = get_target_from_environment());
14960
14961 /** Compile to static-library file and header pair once for each target;
14962 * each resulting function will be considered (in order) via halide_can_use_target_features()
14963 * at runtime, with the first appropriate match being selected for subsequent use.
14964 * This is typically useful for specializations that may vary unpredictably by machine
14965 * (e.g., SSE4.1/AVX/AVX2 on x86 desktop machines).
14966 * All targets must have identical arch-os-bits.
14967 */
14968 void compile_to_multitarget_static_library(const std::string &filename_prefix,
14969 const std::vector<Argument> &args,
14970 const std::vector<Target> &targets);
14971
14972 /** Like compile_to_multitarget_static_library(), except that the object files
14973 * are all output as object files (rather than bundled into a static library).
14974 *
14975 * `suffixes` is an optional list of strings to use for as the suffix for each object
14976 * file. If nonempty, it must be the same length as `targets`. (If empty, Target::to_string()
14977 * will be used for each suffix.)
14978 *
14979 * Note that if `targets.size()` > 1, the wrapper code (to select the subtarget)
14980 * will be generated with the filename `${filename_prefix}_wrapper.o`
14981 *
14982 * Note that if `targets.size()` > 1 and `no_runtime` is not specified, the runtime
14983 * will be generated with the filename `${filename_prefix}_runtime.o`
14984 */
14985 void compile_to_multitarget_object_files(const std::string &filename_prefix,
14986 const std::vector<Argument> &args,
14987 const std::vector<Target> &targets,
14988 const std::vector<std::string> &suffixes);
14989
14990 /** Create an internal representation of lowered code as a self
14991 * contained Module suitable for further compilation. */
14992 Module compile_to_module(const std::vector<Argument> &args,
14993 const std::string &fn_name,
14994 const Target &target = get_target_from_environment(),
14995 LinkageType linkage_type = LinkageType::ExternalPlusMetadata);
14996
14997 /** Eagerly jit compile the function to machine code. This
14998 * normally happens on the first call to realize. If you're
14999 * running your halide pipeline inside time-sensitive code and
15000 * wish to avoid including the time taken to compile a pipeline,
15001 * then you can call this ahead of time. Default is to use the Target
15002 * returned from Halide::get_jit_target_from_environment()
15003 */
15004 void compile_jit(const Target &target = get_jit_target_from_environment());
15005
    /** Set the error handler function to be called in the case of
     * runtime errors during halide pipelines. If you are compiling
15008 * statically, you can also just define your own function with
15009 * signature
15010 \code
15011 extern "C" void halide_error(void *user_context, const char *);
15012 \endcode
15013 * This will clobber Halide's version.
15014 */
15015 void set_error_handler(void (*handler)(void *, const char *));
15016
15017 /** Set a custom malloc and free for halide to use. Malloc should
15018 * return 32-byte aligned chunks of memory, and it should be safe
15019 * for Halide to read slightly out of bounds (up to 8 bytes before
15020 * the start or beyond the end). If compiling statically, routines
15021 * with appropriate signatures can be provided directly
15022 \code
15023 extern "C" void *halide_malloc(void *, size_t)
15024 extern "C" void halide_free(void *, void *)
15025 \endcode
15026 * These will clobber Halide's versions. See HalideRuntime.h
15027 * for declarations.
15028 */
15029 void set_custom_allocator(void *(*malloc)(void *, size_t),
15030 void (*free)(void *, void *));
15031
15032 /** Set a custom task handler to be called by the parallel for
15033 * loop. It is useful to set this if you want to do some
15034 * additional bookkeeping at the granularity of parallel
15035 * tasks. The default implementation does this:
15036 \code
15037 extern "C" int halide_do_task(void *user_context,
15038 int (*f)(void *, int, uint8_t *),
15039 int idx, uint8_t *state) {
15040 return f(user_context, idx, state);
15041 }
15042 \endcode
15043 * If you are statically compiling, you can also just define your
15044 * own version of the above function, and it will clobber Halide's
15045 * version.
15046 *
15047 * If you're trying to use a custom parallel runtime, you probably
15048 * don't want to call this. See instead \ref Func::set_custom_do_par_for .
15049 */
15050 void set_custom_do_task(
15051 int (*custom_do_task)(void *, int (*)(void *, int, uint8_t *),
15052 int, uint8_t *));
15053
15054 /** Set a custom parallel for loop launcher. Useful if your app
15055 * already manages a thread pool. The default implementation is
15056 * equivalent to this:
15057 \code
15058 extern "C" int halide_do_par_for(void *user_context,
15059 int (*f)(void *, int, uint8_t *),
15060 int min, int extent, uint8_t *state) {
15061 int exit_status = 0;
15062 parallel for (int idx = min; idx < min+extent; idx++) {
15063 int job_status = halide_do_task(user_context, f, idx, state);
15064 if (job_status) exit_status = job_status;
15065 }
15066 return exit_status;
15067 }
15068 \endcode
15069 *
15070 * However, notwithstanding the above example code, if one task
15071 * fails, we may skip over other tasks, and if two tasks return
15072 * different error codes, we may select one arbitrarily to return.
15073 *
15074 * If you are statically compiling, you can also just define your
15075 * own version of the above function, and it will clobber Halide's
15076 * version.
15077 */
15078 void set_custom_do_par_for(
15079 int (*custom_do_par_for)(void *, int (*)(void *, int, uint8_t *), int,
15080 int, uint8_t *));
15081
15082 /** Set custom routines to call when tracing is enabled. Call this
15083 * on the output Func of your pipeline. This then sets custom
15084 * routines for the entire pipeline, not just calls to this
15085 * Func.
15086 *
15087 * If you are statically compiling, you can also just define your
15088 * own versions of the tracing functions (see HalideRuntime.h),
15089 * and they will clobber Halide's versions. */
15090 void set_custom_trace(int (*trace_fn)(void *, const halide_trace_event_t *));
15091
15092 /** Set the function called to print messages from the runtime.
15093 * If you are compiling statically, you can also just define your
15094 * own function with signature
15095 \code
15096 extern "C" void halide_print(void *user_context, const char *);
15097 \endcode
15098 * This will clobber Halide's version.
15099 */
15100 void set_custom_print(void (*handler)(void *, const char *));
15101
15102 /** Install a set of external C functions or Funcs to satisfy
15103 * dependencies introduced by HalideExtern and define_extern
15104 * mechanisms. These will be used by calls to realize,
15105 * infer_bounds, and compile_jit. */
15106 void set_jit_externs(const std::map<std::string, JITExtern> &externs);
15107
15108 /** Return the map of previously installed externs. Is an empty
15109 * map unless set otherwise. */
15110 const std::map<std::string, JITExtern> &get_jit_externs();
15111
15112 /** Get a struct containing the currently set custom functions
15113 * used by JIT. */
15114 const Internal::JITHandlers &jit_handlers();
15115
    /** Add a custom pass to be used during lowering. It is run after
     * all other lowering passes. Can be used to verify properties of
     * the lowered Stmt, instrument it with extra code, or otherwise
     * modify it. Ownership of the pass is taken: delete will be called
     * on it when this object goes out of scope. So don't pass a
     * stack object, or share pass instances between multiple
     * pipelines. */
    template<typename T>
    void add_custom_lowering_pass(T *pass) {
        // Template instantiate a custom deleter for this type, then
        // wrap it in a lambda. The custom deleter is instantiated in user
        // code, so that deletion happens on the same heap as construction
        // (this matters on Windows, where each DLL can have its own heap).
        add_custom_lowering_pass(pass, [pass]() { delete_lowering_pass<T>(pass); });
    }
15130
15131 /** Add a custom pass to be used during lowering, with the
15132 * function that will be called to delete it also passed in. Set
15133 * it to nullptr if you wish to retain ownership of the object. */
15134 void add_custom_lowering_pass(Internal::IRMutator *pass, std::function<void()> deleter);
15135
15136 /** Remove all previously-set custom lowering passes */
15137 void clear_custom_lowering_passes();
15138
15139 /** Get the custom lowering passes. */
15140 const std::vector<CustomLoweringPass> &custom_lowering_passes();
15141
15142 /** See Func::realize */
15143 // @{
15144 Realization realize(std::vector<int32_t> sizes = {}, const Target &target = Target(),
15145 const ParamMap &param_map = ParamMap::empty_map());
15146 HALIDE_ATTRIBUTE_DEPRECATED("Call realize() with a vector<int> instead")
15147 Realization realize(int x_size, int y_size, int z_size, int w_size, const Target &target = Target(),
15148 const ParamMap &param_map = ParamMap::empty_map());
15149 HALIDE_ATTRIBUTE_DEPRECATED("Call realize() with a vector<int> instead")
15150 Realization realize(int x_size, int y_size, int z_size, const Target &target = Target(),
15151 const ParamMap &param_map = ParamMap::empty_map());
15152 HALIDE_ATTRIBUTE_DEPRECATED("Call realize() with a vector<int> instead")
15153 Realization realize(int x_size, int y_size, const Target &target = Target(),
15154 const ParamMap &param_map = ParamMap::empty_map());
15155
    // Making this a template function is a trick: `{intliteral}` is a valid
    // scalar initializer in C++, so `realize({10})` could otherwise bind to
    // this (deprecated) scalar overload. Constraining T to be exactly `int`
    // via enable_if makes a braced initializer match the vector overload
    // above instead, while a plain int argument still selects this overload
    // (and triggers its deprecation warning).
    template<typename T, typename = typename std::enable_if<std::is_same<T, int>::value>::type>
    HALIDE_ATTRIBUTE_DEPRECATED("Call realize() with a vector<int> instead")
    HALIDE_ALWAYS_INLINE Realization realize(T x_size, const Target &target = Target(),
                                             const ParamMap &param_map = ParamMap::empty_map()) {
        // Forward to the vector overload with a single-element size list.
        return realize(std::vector<int32_t>{x_size}, target, param_map);
    }
15164 // @}
15165
15166 /** Evaluate this Pipeline into an existing allocated buffer or
15167 * buffers. If the buffer is also one of the arguments to the
15168 * function, strange things may happen, as the pipeline isn't
15169 * necessarily safe to run in-place. The realization should
15170 * contain one Buffer per tuple component per output Func. For
15171 * each individual output Func, all Buffers must have the same
15172 * shape, but the shape can vary across the different output
15173 * Funcs. This form of realize does *not* automatically copy data
15174 * back from the GPU. */
15175 void realize(RealizationArg output, const Target &target = Target(),
15176 const ParamMap &param_map = ParamMap::empty_map());
15177
15178 /** For a given size of output, or a given set of output buffers,
15179 * determine the bounds required of all unbound ImageParams
15180 * referenced. Communicates the result by allocating new buffers
15181 * of the appropriate size and binding them to the unbound
15182 * ImageParams. */
15183 // @{
15184 void infer_input_bounds(const std::vector<int32_t> &sizes,
15185 const Target &target = get_jit_target_from_environment(),
15186 const ParamMap &param_map = ParamMap::empty_map());
15187 void infer_input_bounds(RealizationArg output,
15188 const Target &target = get_jit_target_from_environment(),
15189 const ParamMap &param_map = ParamMap::empty_map());
15190 // @}
15191
15192 /** Infer the arguments to the Pipeline, sorted into a canonical order:
15193 * all buffers (sorted alphabetically by name), followed by all non-buffers
15194 * (sorted alphabetically by name).
15195 This lets you write things like:
15196 \code
15197 pipeline.compile_to_assembly("/dev/stdout", pipeline.infer_arguments());
15198 \endcode
15199 */
15200 std::vector<Argument> infer_arguments();
15201
15202 /** Check if this pipeline object is defined. That is, does it
15203 * have any outputs? */
15204 bool defined() const;
15205
15206 /** Invalidate any internal cached state, e.g. because Funcs have
15207 * been rescheduled. */
15208 void invalidate_cache();
15209
15210 /** Add a top-level precondition to the generated pipeline,
15211 * expressed as a boolean Expr. The Expr may depend on parameters
15212 * only, and may not call any Func or use a Var. If the condition
15213 * is not true at runtime, the pipeline will call halide_error
15214 * with the remaining arguments, and return
15215 * halide_error_code_requirement_failed. Requirements are checked
15216 * in the order added. */
15217 void add_requirement(const Expr &condition, std::vector<Expr> &error);
15218
15219 /** Generate begin_pipeline and end_pipeline tracing calls for this pipeline. */
15220 void trace_pipeline();
15221
    /** Variadic convenience form of add_requirement(): the trailing
     * arguments are collected into the vector of error-message Exprs
     * that is passed to the non-template overload. */
    template<typename... Args>
    inline HALIDE_NO_USER_CODE_INLINE void add_requirement(const Expr &condition, Args &&...args) {
        // Flatten the argument pack into a vector of Exprs (using the same
        // machinery as print()), then delegate to the vector overload.
        std::vector<Expr> collected_args;
        Internal::collect_print_args(collected_args, std::forward<Args>(args)...);
        add_requirement(condition, collected_args);
    }
15228
15229private:
15230 std::string generate_function_name() const;
15231};
15232
15233struct ExternSignature {
15234private:
15235 Type ret_type_; // Only meaningful if is_void_return is false; must be default value otherwise
15236 bool is_void_return_{false};
15237 std::vector<Type> arg_types_;
15238
15239public:
15240 ExternSignature() = default;
15241
15242 ExternSignature(const Type &ret_type, bool is_void_return, const std::vector<Type> &arg_types)
15243 : ret_type_(ret_type),
15244 is_void_return_(is_void_return),
15245 arg_types_(arg_types) {
15246 internal_assert(!(is_void_return && ret_type != Type()));
15247 }
15248
15249 template<typename RT, typename... Args>
15250 explicit ExternSignature(RT (*f)(Args... args))
15251 : ret_type_(type_of<RT>()),
15252 is_void_return_(std::is_void<RT>::value),
15253 arg_types_({type_of<Args>()...}) {
15254 }
15255
15256 const Type &ret_type() const {
15257 internal_assert(!is_void_return());
15258 return ret_type_;
15259 }
15260
15261 bool is_void_return() const {
15262 return is_void_return_;
15263 }
15264
15265 const std::vector<Type> &arg_types() const {
15266 return arg_types_;
15267 }
15268
15269 friend std::ostream &operator<<(std::ostream &stream, const ExternSignature &sig) {
15270 if (sig.is_void_return_) {
15271 stream << "void";
15272 } else {
15273 stream << sig.ret_type_;
15274 }
15275 stream << " (*)(";
15276 bool comma = false;
15277 for (const auto &t : sig.arg_types_) {
15278 if (comma) {
15279 stream << ", ";
15280 }
15281 stream << t;
15282 comma = true;
15283 }
15284 stream << ")";
15285 return stream;
15286 }
15287};
15288
struct ExternCFunction {
private:
    void *address_{nullptr};
    ExternSignature signature_;

public:
    /** Construct with a null address and an empty signature. */
    ExternCFunction() = default;

    /** Wrap a raw function address together with its signature. */
    ExternCFunction(void *address, const ExternSignature &signature)
        : address_(address), signature_(signature) {
    }

    /** Wrap a C function pointer, deducing the signature from the
     * pointer's type. Deliberately non-explicit, so that plain function
     * pointers convert implicitly. */
    template<typename RT, typename... Args>
    ExternCFunction(RT (*f)(Args... args))
        : ExternCFunction((void *)f, ExternSignature(f)) {
    }

    /** The address of the wrapped function. */
    void *address() const {
        return address_;
    }
    /** The signature of the wrapped function. */
    const ExternSignature &signature() const {
        return signature_;
    }
};
15313
struct JITExtern {
private:
    // Note that exactly one of pipeline_ and extern_c_function_
    // can be set in a given JITExtern instance.
    Pipeline pipeline_;
    ExternCFunction extern_c_function_;

public:
    /** Wrap a Pipeline to be used as a jit-time extern. */
    explicit JITExtern(Pipeline pipeline);
    /** Wrap a Func to be used as a jit-time extern (presumably via the
     * Func's pipeline — the constructor is defined out-of-line). */
    explicit JITExtern(const Func &func);
    /** Wrap a C function (address plus signature) to be used as a
     * jit-time extern. */
    explicit JITExtern(const ExternCFunction &extern_c_function);

    /** Convenience constructor: wrap a C function pointer directly,
     * deducing its signature from the pointer's type. */
    template<typename RT, typename... Args>
    explicit JITExtern(RT (*f)(Args... args))
        : JITExtern(ExternCFunction(f)) {
    }

    /** The wrapped Pipeline. Undefined if this instance wraps a C
     * function instead (see the note on the members above). */
    const Pipeline &pipeline() const {
        return pipeline_;
    }
    /** The wrapped C function. Has a null address if this instance
     * wraps a Pipeline instead. */
    const ExternCFunction &extern_c_function() const {
        return extern_c_function_;
    }
};
15338
15339} // namespace Halide
15340
15341#endif
15342#ifndef HALIDE_RDOM_H
15343#define HALIDE_RDOM_H
15344
15345/** \file
15346 * Defines the front-end syntax for reduction domains and reduction
15347 * variables.
15348 */
15349
15350#include <iostream>
15351#include <string>
15352#include <utility>
15353#include <vector>
15354
15355
15356namespace Halide {
15357
15358template<typename T>
15359class Buffer;
15360class OutputImageParam;
15361
15362/** A reduction variable represents a single dimension of a reduction
15363 * domain (RDom). Don't construct them directly, instead construct an
15364 * RDom, and use RDom::operator[] to get at the variables. For
15365 * single-dimensional reduction domains, you can just cast a
15366 * single-dimensional RDom to an RVar. */
15367class RVar {
15368 std::string _name;
15369 Internal::ReductionDomain _domain;
15370 int _index = -1;
15371
15372 const Internal::ReductionVariable &_var() const {
15373 const auto &d = _domain.domain();
15374 internal_assert(_index >= 0 && _index < (int)d.size());
15375 return d.at(_index);
15376 }
15377
15378public:
15379 /** An empty reduction variable. */
15380 RVar()
15381 : _name(Internal::make_entity_name(this, "Halide:.*:RVar", 'r')) {
15382 }
15383
15384 /** Construct an RVar with the given name */
15385 explicit RVar(const std::string &n)
15386 : _name(n) {
15387 }
15388
15389 /** Construct a reduction variable with the given name and
15390 * bounds. Must be a member of the given reduction domain. */
15391 RVar(Internal::ReductionDomain domain, int index)
15392 : _domain(std::move(domain)), _index(index) {
15393 }
15394
15395 /** The minimum value that this variable will take on */
15396 Expr min() const;
15397
15398 /** The number that this variable will take on. The maximum value
15399 * of this variable will be min() + extent() - 1 */
15400 Expr extent() const;
15401
15402 /** The reduction domain this is associated with. */
15403 Internal::ReductionDomain domain() const {
15404 return _domain;
15405 }
15406
15407 /** The name of this reduction variable */
15408 const std::string &name() const;
15409
15410 /** Reduction variables can be used as expressions. */
15411 operator Expr() const;
15412};
15413
15414/** A multi-dimensional domain over which to iterate. Used when
15415 * defining functions with update definitions.
15416 *
 * A reduction is a function with a two-part definition. It has an
15418 * initial value, which looks much like a pure function, and an update
15419 * definition, which may refer to some RDom. Evaluating such a
15420 * function first initializes it over the required domain (which is
15421 * inferred based on usage), and then runs update rule for all points
15422 * in the RDom. For example:
15423 *
15424 \code
15425 Func f;
15426 Var x;
15427 RDom r(0, 10);
15428 f(x) = x; // the initial value
15429 f(r) = f(r) * 2;
15430 Buffer<int> result = f.realize({10});
15431 \endcode
15432 *
15433 * This function creates a single-dimensional buffer of size 10, in
15434 * which element x contains the value x*2. Internally, first the
15435 * initialization rule fills in x at every site, and then the update
15436 * definition doubles every site.
15437 *
15438 * One use of reductions is to build a function recursively (pure
15439 * functions in halide cannot be recursive). For example, this
15440 * function fills in an array with the first 20 fibonacci numbers:
15441 *
15442 \code
15443 Func f;
15444 Var x;
15445 RDom r(2, 18);
15446 f(x) = 1;
15447 f(r) = f(r-1) + f(r-2);
15448 \endcode
15449 *
15450 * Another use of reductions is to perform scattering operations, as
15451 * unlike a pure function declaration, the left-hand-side of an update
15452 * definition may contain general expressions:
15453 *
15454 \code
15455 ImageParam input(UInt(8), 2);
15456 Func histogram;
15457 Var x;
15458 RDom r(input); // Iterate over all pixels in the input
15459 histogram(x) = 0;
15460 histogram(input(r.x, r.y)) = histogram(input(r.x, r.y)) + 1;
15461 \endcode
15462 *
15463 * An update definition may also be multi-dimensional. This example
15464 * computes a summed-area table by first summing horizontally and then
15465 * vertically:
15466 *
15467 \code
15468 ImageParam input(Float(32), 2);
15469 Func sum_x, sum_y;
15470 Var x, y;
15471 RDom r(input);
15472 sum_x(x, y) = input(x, y);
15473 sum_x(r.x, r.y) = sum_x(r.x, r.y) + sum_x(r.x-1, r.y);
15474 sum_y(x, y) = sum_x(x, y);
15475 sum_y(r.x, r.y) = sum_y(r.x, r.y) + sum_y(r.x, r.y-1);
15476 \endcode
15477 *
15478 * You can also mix pure dimensions with reduction variables. In the
15479 * previous example, note that there's no need for the y coordinate in
15480 * sum_x to be traversed serially. The sum within each row is entirely
15481 * independent. The rows could be computed in parallel, or in a
15482 * different order, without changing the meaning. Therefore, we can
15483 * instead write this definition as follows:
15484 *
15485 \code
15486 ImageParam input(Float(32), 2);
15487 Func sum_x, sum_y;
15488 Var x, y;
15489 RDom r(input);
15490 sum_x(x, y) = input(x, y);
15491 sum_x(r.x, y) = sum_x(r.x, y) + sum_x(r.x-1, y);
15492 sum_y(x, y) = sum_x(x, y);
15493 sum_y(x, r.y) = sum_y(x, r.y) + sum_y(x, r.y-1);
15494 \endcode
15495 *
15496 * This lets us schedule it more flexibly. You can now parallelize the
15497 * update step of sum_x over y by calling:
15498 \code
15499 sum_x.update().parallel(y).
15500 \endcode
15501 *
15502 * Note that calling sum_x.parallel(y) only parallelizes the
15503 * initialization step, and not the update step! Scheduling the update
15504 * step of a reduction must be done using the handle returned by
15505 * \ref Func::update(). This code parallelizes both the initialization
15506 * step and the update step:
15507 *
15508 \code
15509 sum_x.parallel(y);
15510 sum_x.update().parallel(y);
15511 \endcode
15512 *
15513 * When you mix reduction variables and pure dimensions, the reduction
15514 * domain is traversed outermost. That is, for each point in the
15515 * reduction domain, the inferred pure domain is traversed in its
15516 * entirety. For the above example, this means that sum_x walks down
15517 * the columns, and sum_y walks along the rows. This may not be
15518 * cache-coherent. You may try reordering these dimensions using the
15519 * schedule, but Halide will return an error if it decides that this
15520 * risks changing the meaning of your function. The solution lies in
15521 * clever scheduling. If we say:
15522 *
15523 \code
15524 sum_x.compute_at(sum_y, y);
15525 \endcode
15526 *
15527 * Then the sum in x is computed only as necessary for each scanline
15528 * of the sum in y. This not only results in sum_x walking along the
15529 * rows, it also improves the locality of the entire pipeline.
15530 */
class RDom {
    // The internal reduction domain object that this class wraps.
    Internal::ReductionDomain dom;

    // Bind this RDom's member RVars (x, y, z, w) once 'dom' has been
    // built. Defined out-of-line.
    void init_vars(const std::string &name);

    // Base case of the recursion below: takes the fully-accumulated
    // region and builds the domain from it. Defined out-of-line.
    void initialize_from_region(const Region &region, std::string name = "");

    // Recursive case: peel one (min, extent) pair off the argument pack,
    // append it to the region, and recurse on the rest. When the pack is
    // empty, the call resolves to the non-template overload above.
    template<typename... Args>
    HALIDE_NO_USER_CODE_INLINE void initialize_from_region(Region &region, const Expr &min, const Expr &extent, Args &&...args) {
        region.push_back({min, extent});
        initialize_from_region(region, std::forward<Args>(args)...);
    }

public:
    /** Construct an undefined reduction domain. */
    RDom() = default;

    /** Construct a multi-dimensional reduction domain with the given name. If the name
     * is left blank, a unique one is auto-generated. */
    // @{
    HALIDE_NO_USER_CODE_INLINE RDom(const Region &region, std::string name = "") {
        initialize_from_region(region, std::move(name));
    }

    template<typename... Args>
    HALIDE_NO_USER_CODE_INLINE RDom(Expr min, Expr extent, Args &&...args) {
        // This should really just be a delegating constructor, but I couldn't make
        // that work with variadic template unpacking in visual studio 2013
        Region region;
        initialize_from_region(region, min, extent, std::forward<Args>(args)...);
    }
    // @}

    /** Construct a reduction domain that iterates over all points in
     * a given Buffer or ImageParam. Has the same dimensionality as
     * the argument. */
    // @{
    RDom(const Buffer<void> &);
    RDom(const OutputImageParam &);
    template<typename T>
    HALIDE_NO_USER_CODE_INLINE RDom(const Buffer<T> &im)
        : RDom(Buffer<void>(im)) {
    }
    // @}

    /** Construct a reduction domain that wraps an Internal ReductionDomain object. */
    RDom(const Internal::ReductionDomain &d);

    /** Get at the internal reduction domain object that this wraps. */
    Internal::ReductionDomain domain() const {
        return dom;
    }

    /** Check if this reduction domain is non-null */
    bool defined() const {
        return dom.defined();
    }

    /** Compare two reduction domains for equality of reference */
    bool same_as(const RDom &other) const {
        return dom.same_as(other.dom);
    }

    /** Get the dimensionality of a reduction domain */
    int dimensions() const;

    /** Get at one of the dimensions of the reduction domain */
    RVar operator[](int) const;

    /** Single-dimensional reduction domains can be used as RVars directly. */
    operator RVar() const;

    /** Single-dimensional reduction domains can also be used as Exprs directly. */
    operator Expr() const;

    /** Add a predicate to the RDom. An RDom may have multiple
     * predicates associated with it. An update definition that uses
     * an RDom only iterates over the subset of points in the domain for
     * which all of its predicates are true. The predicate expression
     * obeys the same rules as the expressions used on the
     * right-hand-side of the corresponding update definition. It may
     * refer to the RDom's variables and free variables in the Func's
     * update definition. It may include calls to other Funcs, or make
     * recursive calls to the same Func. This permits iteration over
     * non-rectangular domains, or domains with sizes that vary with
     * some free variable, or domains with shapes determined by some
     * other Func.
     *
     * Note that once RDom is used in the update definition of some
     * Func, no new predicates can be added to the RDom.
     *
     * Consider a simple example:
     \code
     RDom r(0, 20, 0, 20);
     r.where(r.x < r.y);
     r.where(r.x == 10);
     r.where(r.y > 13);
     f(r.x, r.y) += 1;
     \endcode
     * This is equivalent to:
     \code
     for (int r.y = 0; r.y < 20; r.y++) {
       if (r.y > 13) {
         for (int r.x = 0; r.x < 20; r.x++) {
           if (r.x == 10) {
             if (r.x < r.y) {
               f[r.x, r.y] += 1;
             }
           }
         }
       }
     }
     \endcode
     *
     * Where possible Halide restricts the range of the containing for
     * loops to avoid the cases where the predicate is false so that
     * the if statement can be removed entirely. The case above would
     * be further simplified into:
     *
     \code
     for (int r.y = 14; r.y < 20; r.y++) {
       f[r.x, r.y] += 1;
     }
     \endcode
     *
     * In general, the predicates that we can simplify away by
     * restricting loop ranges are inequalities that compare an inner
     * Var or RVar to some expression in outer Vars or RVars.
     *
     * You can also pack multiple conditions into one predicate like so:
     *
     \code
     RDom r(0, 20, 0, 20);
     r.where((r.x < r.y) && (r.x == 10) && (r.y > 13));
     f(r.x, r.y) += 1;
     \endcode
     *
     */
    void where(Expr predicate);

    /** Direct access to the first four dimensions of the reduction
     * domain. Some of these variables may be undefined if the
     * reduction domain has fewer than four dimensions. */
    // @{
    RVar x, y, z, w;
    // @}
};
15678
15679/** Emit an RVar in a human-readable form */
15680std::ostream &operator<<(std::ostream &stream, const RVar &);
15681
15682/** Emit an RDom in a human-readable form. */
15683std::ostream &operator<<(std::ostream &stream, const RDom &);
15684} // namespace Halide
15685
15686#endif
15687#ifndef HALIDE_VAR_H
15688#define HALIDE_VAR_H
15689
15690/** \file
15691 * Defines the Var - the front-end variable
15692 */
15693#include <string>
15694#include <vector>
15695
15696
15697namespace Halide {
15698
15699/** A Halide variable, to be used when defining functions. It is just
15700 * a name, and can be reused in places where no name conflict will
15701 * occur. It can be used in the left-hand-side of a function
15702 * definition, or as an Expr. As an Expr, it always has type
15703 * Int(32). */
class Var {
    /* The expression representing the Var. Guaranteed to be an
     * Internal::Variable of type Int(32). Created once on
     * construction of the Var to avoid making a fresh Expr every time
     * the Var is used in a context in which it will be converted to
     * one. */
    Expr e;

public:
    /** Construct a Var with the given name */
    Var(const std::string &n);

    /** Construct a Var with an automatically-generated unique name. */
    Var();

    /** Get the name of a Var */
    const std::string &name() const;

    /** Test if two Vars are the same. This simply compares the names. */
    bool same_as(const Var &other) const {
        return name() == other.name();
    }

    /** Implicit var constructor. Implicit variables are injected
     * automatically into a function call if the number of arguments
     * to the function are fewer than its dimensionality and a
     * placeholder ("_") appears in its argument list. Defining a
     * function to equal an expression containing implicit variables
     * similarly appends those implicit variables, in the same order,
     * to the left-hand-side of the definition where the placeholder
     * ('_') appears.
     *
     * For example, consider the definition:
     *
     \code
     Func f, g;
     Var x, y;
     f(x, y) = 3;
     \endcode
     *
     * A call to f with the placeholder symbol _
     * will have implicit arguments injected automatically, so f(2, _)
     * is equivalent to f(2, _0), where _0 = ImplicitVar<0>(), and f(_)
     * (and indeed f when cast to an Expr) is equivalent to f(_0, _1).
     * The following definitions are all equivalent, differing only in the
     * variable names.
     *
     \code
     g(_) = f*3;
     g(_) = f(_)*3;
     g(x, _) = f(x, _)*3;
     g(x, y) = f(x, y)*3;
     \endcode
     *
     * These are expanded internally as follows:
     *
     \code
     g(_0, _1) = f(_0, _1)*3;
     g(_0, _1) = f(_0, _1)*3;
     g(x, _0) = f(x, _0)*3;
     g(x, y) = f(x, y)*3;
     \endcode
     *
     * The following, however, defines g as four dimensional:
     \code
     g(x, y, _) = f*3;
     \endcode
     *
     * It is equivalent to:
     *
     \code
     g(x, y, _0, _1) = f(_0, _1)*3;
     \endcode
     *
     * Expressions requiring differing numbers of implicit variables
     * can be combined. The left-hand-side of a definition injects
     * enough implicit variables to cover all of them:
     *
     \code
     Func h;
     h(x) = x*3;
     g(x) = h + (f + f(x)) * f(x, y);
     \endcode
     *
     * expands to:
     *
     \code
     Func h;
     h(x) = x*3;
     g(x, _0, _1) = h(_0) + (f(_0, _1) + f(x, _0)) * f(x, y);
     \endcode
     *
     * The first ten implicits, _0 through _9, are predeclared in this
     * header and can be used for scheduling. They should never be
     * used as arguments in a declaration or used in a call.
     *
     * While it is possible to use Var::implicit or the predeclared
     * implicits to create expressions that can be treated as small
     * anonymous functions (e.g. Func(_0 + _1)) this is considered
     * poor style. Instead use \ref lambda.
     */
    static Var implicit(int n);

    /** Return whether a variable name is of the form for an implicit argument.
     * TODO: This is almost guaranteed to incorrectly fire on user
     * declared variables at some point. We should likely prevent
     * user Var declarations from making names of this form.
     */
    //{
    static bool is_implicit(const std::string &name);
    bool is_implicit() const {
        return is_implicit(name());
    }
    //}

    /** Return the argument index for a placeholder argument given its
     * name. Returns 0 for _0, 1 for _1, etc. Returns -1 if
     * the variable is not of implicit form.
     */
    //{
    static int implicit_index(const std::string &name) {
        // Implicit names have the form "_<index>"; parse the digits
        // that follow the leading underscore.
        return is_implicit(name) ? atoi(name.c_str() + 1) : -1;
    }
    int implicit_index() const {
        return implicit_index(name());
    }
    //}

    /** Test if a var is the placeholder variable _ */
    //{
    static bool is_placeholder(const std::string &name) {
        return name == "_";
    }
    bool is_placeholder() const {
        return is_placeholder(name());
    }
    //}

    /** A Var can be treated as an Expr of type Int(32) */
    operator const Expr &() const {
        return e;
    }

    /** A Var that represents the location outside the outermost loop. */
    static Var outermost() {
        return Var("__outermost");
    }
};
15852
15853template<int N = -1>
15854struct ImplicitVar {
15855 Var to_var() const {
15856 if (N >= 0) {
15857 return Var::implicit(N);
15858 } else {
15859 return Var("_");
15860 }
15861 }
15862
15863 operator Var() const {
15864 return to_var();
15865 }
15866 operator Expr() const {
15867 return to_var();
15868 }
15869};
15870
/** A placeholder variable for inferred arguments. See \ref Var::implicit */
// NOTE(review): these header-scope 'static constexpr' variables give every
// translation unit its own copy; C++17 'inline constexpr' would share one
// definition — confirm the minimum supported standard before changing.
static constexpr ImplicitVar<> _;

/** The first ten implicit Vars for use in scheduling. See \ref Var::implicit */
// @{
static constexpr ImplicitVar<0> _0;
static constexpr ImplicitVar<1> _1;
static constexpr ImplicitVar<2> _2;
static constexpr ImplicitVar<3> _3;
static constexpr ImplicitVar<4> _4;
static constexpr ImplicitVar<5> _5;
static constexpr ImplicitVar<6> _6;
static constexpr ImplicitVar<7> _7;
static constexpr ImplicitVar<8> _8;
static constexpr ImplicitVar<9> _9;
// @}

namespace Internal {

/** Make a list of unique arguments for definitions with unnamed
 arguments. */
std::vector<Var> make_argument_list(int dimensionality);

} // namespace Internal
15895
15896} // namespace Halide
15897
15898#endif
15899
15900#include <map>
15901#include <utility>
15902
15903namespace Halide {
15904
15905class OutputImageParam;
15906class ParamMap;
15907
15908/** A class that can represent Vars or RVars. Used for reorder calls
15909 * which can accept a mix of either. */
15910struct VarOrRVar {
15911 VarOrRVar(const std::string &n, bool r)
15912 : var(n), rvar(n), is_rvar(r) {
15913 }
15914 VarOrRVar(const Var &v)
15915 : var(v), is_rvar(false) {
15916 }
15917 VarOrRVar(const RVar &r)
15918 : rvar(r), is_rvar(true) {
15919 }
15920 VarOrRVar(const RDom &r)
15921 : rvar(RVar(r)), is_rvar(true) {
15922 }
15923 template<int N>
15924 VarOrRVar(const ImplicitVar<N> &u)
15925 : var(u), is_rvar(false) {
15926 }
15927
15928 const std::string &name() const {
15929 if (is_rvar) {
15930 return rvar.name();
15931 } else {
15932 return var.name();
15933 }
15934 }
15935
15936 Var var;
15937 RVar rvar;
15938 bool is_rvar;
15939};
15940
15941class ImageParam;
15942
15943namespace Internal {
15944class Function;
15945struct Split;
15946struct StorageDim;
15947} // namespace Internal
15948
/** A single definition of a Func. May be a pure or update definition. */
class Stage {
    /** Reference to the Function this stage (or definition) belongs to. */
    Internal::Function function;
    /** The pure or update definition this Stage wraps. */
    Internal::Definition definition;
    /** Indicate which stage the definition belongs to (0 for initial
     * definition, 1 for first update, etc.). */
    size_t stage_index;
    /** Pure Vars of the Function (from the init definition). */
    std::vector<Var> dim_vars;

    // Private helpers shared by the public scheduling calls below.
    void set_dim_type(const VarOrRVar &var, Internal::ForType t);
    void set_dim_device_api(const VarOrRVar &var, DeviceAPI device_api);
    void split(const std::string &old, const std::string &outer, const std::string &inner,
               const Expr &factor, bool exact, TailStrategy tail);
    void remove(const std::string &var);
    Stage &purify(const VarOrRVar &old_name, const VarOrRVar &new_name);

    const std::vector<Internal::StorageDim> &storage_dims() const {
        return function.schedule().storage_dims();
    }

    Stage &compute_with(LoopLevel loop_level, const std::map<std::string, LoopAlignStrategy> &align);

public:
    Stage(Internal::Function f, Internal::Definition d, size_t stage_index)
        : function(std::move(f)), definition(std::move(d)), stage_index(stage_index) {
        internal_assert(definition.defined());
        // Record that this definition's schedule has been handled by the front end.
        definition.schedule().touched() = true;

        dim_vars.reserve(function.args().size());
        for (const auto &arg : function.args()) {
            dim_vars.emplace_back(arg);
        }
        internal_assert(definition.args().size() == dim_vars.size());
    }

    /** Return the current StageSchedule associated with this Stage. For
     * introspection only: to modify schedule, use the Func interface. */
    const Internal::StageSchedule &get_schedule() const {
        return definition.schedule();
    }

    /** Return a string describing the current var list taking into
     * account all the splits, reorders, and tiles. */
    std::string dump_argument_list() const;

    /** Return the name of this stage, e.g. "f.update(2)" */
    std::string name() const;

    /** Calling rfactor() on an associative update definition of a Func will split
     * the update into an intermediate which computes the partial results and
     * replaces the current update definition with a new definition which merges
     * the partial results. If called on an init/pure definition, this will
     * throw an error. rfactor() will automatically infer the associative reduction
     * operator and identity of the operator. If it can't prove the operation
     * is associative or if it cannot find an identity for that operator, this
     * will throw an error. In addition, commutativity of the operator is required
     * if rfactor() is called on the inner dimension but excluding the outer
     * dimensions.
     *
     * rfactor() takes as input 'preserved', which is a list of <RVar, Var> pairs.
     * The rvars not listed in 'preserved' are removed from the original Func and
     * are lifted to the intermediate Func. The remaining rvars (the ones in
     * 'preserved') are made pure in the intermediate Func. The intermediate Func's
     * update definition inherits all scheduling directives (e.g. split, fuse, etc.)
     * applied to the original Func's update definition. The loop order of the
     * intermediate Func's update definition is the same as the original, although
     * the RVars in 'preserved' are replaced by the new pure Vars. The loop order of the
     * intermediate Func's init definition from innermost to outermost is the args'
     * order of the original Func's init definition followed by the new pure Vars.
     *
     * The intermediate Func also inherits storage order from the original Func
     * with the new pure Vars added to the outermost.
     *
     * For example, f.update(0).rfactor({{r.y, u}}) would rewrite a pipeline like this:
     \code
     f(x, y) = 0;
     f(x, y) += g(r.x, r.y);
     \endcode
     * into a pipeline like this:
     \code
     f_intm(x, y, u) = 0;
     f_intm(x, y, u) += g(r.x, u);

     f(x, y) = 0;
     f(x, y) += f_intm(x, y, r.y);
     \endcode
     *
     * This has a variety of uses. You can use it to split computation of an associative reduction:
     \code
     f(x, y) = 10;
     RDom r(0, 96);
     f(x, y) = max(f(x, y), g(x, y, r.x));
     f.update(0).split(r.x, rxo, rxi, 8).reorder(y, x).parallel(x);
     f.update(0).rfactor({{rxo, u}}).compute_root().parallel(u).update(0).parallel(u);
     \endcode
     *
     * which is equivalent to:
     \code
     parallel for u = 0 to 11:
       for y:
         for x:
           f_intm(x, y, u) = -inf
     parallel for x:
       for y:
         parallel for u = 0 to 11:
           for rxi = 0 to 7:
             f_intm(x, y, u) = max(f_intm(x, y, u), g(8*u + rxi))
     for y:
       for x:
         f(x, y) = 10
     parallel for x:
       for y:
         for rxo = 0 to 11:
           f(x, y) = max(f(x, y), f_intm(x, y, rxo))
     \endcode
     *
     */
    // @{
    Func rfactor(std::vector<std::pair<RVar, Var>> preserved);
    Func rfactor(const RVar &r, const Var &v);
    // @}

    /** Schedule the iteration over this stage to be fused with another
     * stage 's' from outermost loop to a given LoopLevel. 'this' stage will
     * be computed AFTER 's' in the innermost fused dimension. There should not
     * be any dependencies between those two fused stages. If either of the
     * stages being fused is a stage of an extern Func, this will throw an error.
     *
     * Note that the two stages that are fused together should have the same
     * exact schedule from the outermost to the innermost fused dimension, and
     * the stage we are calling compute_with on should not have specializations,
     * e.g. f2.compute_with(f1, x) is allowed only if f2 has no specializations.
     *
     * Also, if a producer is desired to be computed at the fused loop level,
     * the function passed to the compute_at() needs to be the "parent". Consider
     * the following code:
     \code
     input(x, y) = x + y;
     f(x, y) = input(x, y);
     f(x, y) += 5;
     g(x, y) = x - y;
     g(x, y) += 10;
     f.compute_with(g, y);
     f.update().compute_with(g.update(), y);
     \endcode
     *
     * To compute 'input' at the fused loop level at dimension y, we specify
     * input.compute_at(g, y) instead of input.compute_at(f, y) since 'g' is
     * the "parent" for this fused loop (i.e. 'g' is computed first before 'f'
     * is computed). On the other hand, to compute 'input' at the innermost
     * dimension of 'f', we specify input.compute_at(f, x) instead of
     * input.compute_at(g, x) since the x dimension of 'f' is not fused
     * (only the y dimension is).
     *
     * Given the constraints, this has a variety of uses. Consider the
     * following code:
     \code
     f(x, y) = x + y;
     g(x, y) = x - y;
     h(x, y) = f(x, y) + g(x, y);
     f.compute_root();
     g.compute_root();
     f.split(x, xo, xi, 8);
     g.split(x, xo, xi, 8);
     g.compute_with(f, xo);
     \endcode
     *
     * This is equivalent to:
     \code
     for y:
       for xo:
         for xi:
           f(8*xo + xi) = (8*xo + xi) + y
         for xi:
           g(8*xo + xi) = (8*xo + xi) - y
     for y:
       for x:
         h(x, y) = f(x, y) + g(x, y)
     \endcode
     *
     * The size of the dimensions of the stages computed_with do not have
     * to match. Consider the following code where 'g' is half the size of 'f':
     \code
     Image<int> f_im(size, size), g_im(size/2, size/2);
     input(x, y) = x + y;
     f(x, y) = input(x, y);
     g(x, y) = input(2*x, 2*y);
     g.compute_with(f, y);
     input.compute_at(f, y);
     Pipeline({f, g}).realize({f_im, g_im});
     \endcode
     *
     * This is equivalent to:
     \code
     for y = 0 to size-1:
       for x = 0 to size-1:
         input(x, y) = x + y;
       for x = 0 to size-1:
         f(x, y) = input(x, y)
       for x = 0 to size/2-1:
         if (y < size/2-1):
           g(x, y) = input(2*x, 2*y)
     \endcode
     *
     * 'align' specifies how the loop iteration of each dimension of the
     * two stages being fused should be aligned in the fused loop nests
     * (see LoopAlignStrategy for options). Consider the following loop nests:
     \code
     for z = f_min_z to f_max_z:
       for y = f_min_y to f_max_y:
         for x = f_min_x to f_max_x:
           f(x, y, z) = x + y + z
     for z = g_min_z to g_max_z:
       for y = g_min_y to g_max_y:
         for x = g_min_x to g_max_x:
           g(x, y, z) = x - y - z
     \endcode
     *
     * If no alignment strategy is specified, the following loop nest will be
     * generated:
     \code
     for z = min(f_min_z, g_min_z) to max(f_max_z, g_max_z):
       for y = min(f_min_y, g_min_y) to max(f_max_y, g_max_y):
         for x = f_min_x to f_max_x:
           if (f_min_z <= z <= f_max_z):
             if (f_min_y <= y <= f_max_y):
               f(x, y, z) = x + y + z
         for x = g_min_x to g_max_x:
           if (g_min_z <= z <= g_max_z):
             if (g_min_y <= y <= g_max_y):
               g(x, y, z) = x - y - z
     \endcode
     *
     * Instead, these alignment strategies:
     \code
     g.compute_with(f, y, {{z, LoopAlignStrategy::AlignStart}, {y, LoopAlignStrategy::AlignEnd}});
     \endcode
     * will produce the following loop nest:
     \code
     f_loop_min_z = f_min_z
     f_loop_max_z = max(f_max_z, (f_min_z - g_min_z) + g_max_z)
     for z = f_min_z to f_loop_max_z:
       f_loop_min_y = min(f_min_y, (f_max_y - g_max_y) + g_min_y)
       f_loop_max_y = f_max_y
       for y = f_loop_min_y to f_loop_max_y:
         for x = f_min_x to f_max_x:
           if (f_loop_min_z <= z <= f_loop_max_z):
             if (f_loop_min_y <= y <= f_loop_max_y):
               f(x, y, z) = x + y + z
         for x = g_min_x to g_max_x:
           g_shift_z = g_min_z - f_loop_min_z
           g_shift_y = g_max_y - f_loop_max_y
           if (g_min_z <= (z + g_shift_z) <= g_max_z):
             if (g_min_y <= (y + g_shift_y) <= g_max_y):
               g(x, y + g_shift_y, z + g_shift_z) = x - (y + g_shift_y) - (z + g_shift_z)
     \endcode
     *
     * LoopAlignStrategy::AlignStart on dimension z will shift the loop iteration
     * of 'g' at dimension z so that its starting value matches that of 'f'.
     * Likewise, LoopAlignStrategy::AlignEnd on dimension y will shift the loop
     * iteration of 'g' at dimension y so that its end value matches that of 'f'.
     */
    // @{
    Stage &compute_with(LoopLevel loop_level, const std::vector<std::pair<VarOrRVar, LoopAlignStrategy>> &align);
    Stage &compute_with(LoopLevel loop_level, LoopAlignStrategy align = LoopAlignStrategy::Auto);
    Stage &compute_with(const Stage &s, const VarOrRVar &var, const std::vector<std::pair<VarOrRVar, LoopAlignStrategy>> &align);
    Stage &compute_with(const Stage &s, const VarOrRVar &var, LoopAlignStrategy align = LoopAlignStrategy::Auto);
    // @}

    /** Scheduling calls that control how the domain of this stage is
     * traversed. See the documentation for Func for the meanings. */
    // @{

    Stage &split(const VarOrRVar &old, const VarOrRVar &outer, const VarOrRVar &inner, const Expr &factor, TailStrategy tail = TailStrategy::Auto);
    Stage &fuse(const VarOrRVar &inner, const VarOrRVar &outer, const VarOrRVar &fused);
    Stage &serial(const VarOrRVar &var);
    Stage &parallel(const VarOrRVar &var);
    Stage &vectorize(const VarOrRVar &var);
    Stage &unroll(const VarOrRVar &var);
    Stage &parallel(const VarOrRVar &var, const Expr &task_size, TailStrategy tail = TailStrategy::Auto);
    Stage &vectorize(const VarOrRVar &var, const Expr &factor, TailStrategy tail = TailStrategy::Auto);
    Stage &unroll(const VarOrRVar &var, const Expr &factor, TailStrategy tail = TailStrategy::Auto);
    Stage &tile(const VarOrRVar &x, const VarOrRVar &y,
                const VarOrRVar &xo, const VarOrRVar &yo,
                const VarOrRVar &xi, const VarOrRVar &yi, const Expr &xfactor, const Expr &yfactor,
                TailStrategy tail = TailStrategy::Auto);
    Stage &tile(const VarOrRVar &x, const VarOrRVar &y,
                const VarOrRVar &xi, const VarOrRVar &yi,
                const Expr &xfactor, const Expr &yfactor,
                TailStrategy tail = TailStrategy::Auto);
    Stage &tile(const std::vector<VarOrRVar> &previous,
                const std::vector<VarOrRVar> &outers,
                const std::vector<VarOrRVar> &inners,
                const std::vector<Expr> &factors,
                const std::vector<TailStrategy> &tails);
    Stage &tile(const std::vector<VarOrRVar> &previous,
                const std::vector<VarOrRVar> &outers,
                const std::vector<VarOrRVar> &inners,
                const std::vector<Expr> &factors,
                TailStrategy tail = TailStrategy::Auto);
    Stage &tile(const std::vector<VarOrRVar> &previous,
                const std::vector<VarOrRVar> &inners,
                const std::vector<Expr> &factors,
                TailStrategy tail = TailStrategy::Auto);
    Stage &reorder(const std::vector<VarOrRVar> &vars);

    template<typename... Args>
    HALIDE_NO_USER_CODE_INLINE typename std::enable_if<Internal::all_are_convertible<VarOrRVar, Args...>::value, Stage &>::type
    reorder(const VarOrRVar &x, const VarOrRVar &y, Args &&...args) {
        std::vector<VarOrRVar> collected_args{x, y, std::forward<Args>(args)...};
        return reorder(collected_args);
    }

    Stage &rename(const VarOrRVar &old_name, const VarOrRVar &new_name);
    Stage specialize(const Expr &condition);
    void specialize_fail(const std::string &message);

    Stage &gpu_threads(const VarOrRVar &thread_x, DeviceAPI device_api = DeviceAPI::Default_GPU);
    Stage &gpu_threads(const VarOrRVar &thread_x, const VarOrRVar &thread_y, DeviceAPI device_api = DeviceAPI::Default_GPU);
    Stage &gpu_threads(const VarOrRVar &thread_x, const VarOrRVar &thread_y, const VarOrRVar &thread_z, DeviceAPI device_api = DeviceAPI::Default_GPU);

    Stage &gpu_lanes(const VarOrRVar &thread_x, DeviceAPI device_api = DeviceAPI::Default_GPU);

    Stage &gpu_single_thread(DeviceAPI device_api = DeviceAPI::Default_GPU);

    Stage &gpu_blocks(const VarOrRVar &block_x, DeviceAPI device_api = DeviceAPI::Default_GPU);
    Stage &gpu_blocks(const VarOrRVar &block_x, const VarOrRVar &block_y, DeviceAPI device_api = DeviceAPI::Default_GPU);
    Stage &gpu_blocks(const VarOrRVar &block_x, const VarOrRVar &block_y, const VarOrRVar &block_z, DeviceAPI device_api = DeviceAPI::Default_GPU);

    Stage &gpu(const VarOrRVar &block_x, const VarOrRVar &thread_x, DeviceAPI device_api = DeviceAPI::Default_GPU);
    Stage &gpu(const VarOrRVar &block_x, const VarOrRVar &block_y,
               const VarOrRVar &thread_x, const VarOrRVar &thread_y,
               DeviceAPI device_api = DeviceAPI::Default_GPU);
    Stage &gpu(const VarOrRVar &block_x, const VarOrRVar &block_y, const VarOrRVar &block_z,
               const VarOrRVar &thread_x, const VarOrRVar &thread_y, const VarOrRVar &thread_z,
               DeviceAPI device_api = DeviceAPI::Default_GPU);

    Stage &gpu_tile(const VarOrRVar &x, const VarOrRVar &bx, const VarOrRVar &tx, const Expr &x_size,
                    TailStrategy tail = TailStrategy::Auto,
                    DeviceAPI device_api = DeviceAPI::Default_GPU);

    Stage &gpu_tile(const VarOrRVar &x, const VarOrRVar &tx, const Expr &x_size,
                    TailStrategy tail = TailStrategy::Auto,
                    DeviceAPI device_api = DeviceAPI::Default_GPU);
    Stage &gpu_tile(const VarOrRVar &x, const VarOrRVar &y,
                    const VarOrRVar &bx, const VarOrRVar &by,
                    const VarOrRVar &tx, const VarOrRVar &ty,
                    const Expr &x_size, const Expr &y_size,
                    TailStrategy tail = TailStrategy::Auto,
                    DeviceAPI device_api = DeviceAPI::Default_GPU);

    Stage &gpu_tile(const VarOrRVar &x, const VarOrRVar &y,
                    const VarOrRVar &tx, const VarOrRVar &ty,
                    const Expr &x_size, const Expr &y_size,
                    TailStrategy tail = TailStrategy::Auto,
                    DeviceAPI device_api = DeviceAPI::Default_GPU);

    Stage &gpu_tile(const VarOrRVar &x, const VarOrRVar &y, const VarOrRVar &z,
                    const VarOrRVar &bx, const VarOrRVar &by, const VarOrRVar &bz,
                    const VarOrRVar &tx, const VarOrRVar &ty, const VarOrRVar &tz,
                    const Expr &x_size, const Expr &y_size, const Expr &z_size,
                    TailStrategy tail = TailStrategy::Auto,
                    DeviceAPI device_api = DeviceAPI::Default_GPU);
    Stage &gpu_tile(const VarOrRVar &x, const VarOrRVar &y, const VarOrRVar &z,
                    const VarOrRVar &tx, const VarOrRVar &ty, const VarOrRVar &tz,
                    const Expr &x_size, const Expr &y_size, const Expr &z_size,
                    TailStrategy tail = TailStrategy::Auto,
                    DeviceAPI device_api = DeviceAPI::Default_GPU);

    Stage &allow_race_conditions();
    Stage &atomic(bool override_associativity_test = false);

    Stage &hexagon(const VarOrRVar &x = Var::outermost());
    Stage &prefetch(const Func &f, const VarOrRVar &var, Expr offset = 1,
                    PrefetchBoundStrategy strategy = PrefetchBoundStrategy::GuardWithIf);
    Stage &prefetch(const Internal::Parameter &param, const VarOrRVar &var, Expr offset = 1,
                    PrefetchBoundStrategy strategy = PrefetchBoundStrategy::GuardWithIf);
    template<typename T>
    Stage &prefetch(const T &image, VarOrRVar var, Expr offset = 1,
                    PrefetchBoundStrategy strategy = PrefetchBoundStrategy::GuardWithIf) {
        return prefetch(image.parameter(), var, offset, strategy);
    }
    // @}

    /** Attempt to get the source file and line where this stage was
     * defined by parsing the process's own debug symbols. Returns an
     * empty string if no debug symbols were found or the debug
     * symbols were not understood. Works on OS X and Linux only. */
    std::string source_location() const;
};
16341
16342// For backwards compatibility, keep the ScheduleHandle name.
16343typedef Stage ScheduleHandle;
16344
16345class FuncTupleElementRef;
16346
/** A fragment of front-end syntax of the form f(x, y, z), where x, y,
 * z are Vars or Exprs. It could be the left hand side of a definition or
 * an update definition, or it could be a call to a function. We don't know
 * until we see how this object gets used.
 */
class FuncRef {
    // The Function this fragment refers to.
    Internal::Function func;
    // Position at which implicit variables were injected (constructors
    // default this to -1 — presumably meaning "no placeholder"; verify
    // against the FuncRef constructor definitions).
    int implicit_placeholder_pos;
    // Number of implicit variables injected (constructors default to 0).
    int implicit_count;
    // The arguments appearing in f(args...).
    std::vector<Expr> args;
    std::vector<Expr> args_with_implicit_vars(const std::vector<Expr> &e) const;

    /** Helper for function update by Tuple. If the function does not
     * already have a pure definition, init_val will be used as RHS of
     * each tuple element in the initial function definition. */
    template<typename BinaryOp>
    Stage func_ref_update(const Tuple &e, int init_val);

    /** Helper for function update by Expr. If the function does not
     * already have a pure definition, init_val will be used as RHS in
     * the initial function definition. */
    template<typename BinaryOp>
    Stage func_ref_update(Expr e, int init_val);

public:
    FuncRef(const Internal::Function &, const std::vector<Expr> &,
            int placeholder_pos = -1, int count = 0);
    FuncRef(Internal::Function, const std::vector<Var> &,
            int placeholder_pos = -1, int count = 0);

    /** Use this as the left-hand-side of a definition or an update definition
     * (see \ref RDom).
     */
    Stage operator=(const Expr &);

    /** Use this as the left-hand-side of a definition or an update definition
     * for a Func with multiple outputs. */
    Stage operator=(const Tuple &);

    /** Define a stage that adds the given expression to this Func. If the
     * expression refers to some RDom, this performs a sum reduction of the
     * expression over the domain. If the function does not already have a
     * pure definition, this sets it to zero.
     */
    // @{
    Stage operator+=(Expr);
    Stage operator+=(const Tuple &);
    Stage operator+=(const FuncRef &);
    // @}

    /** Define a stage that adds the negative of the given expression to this
     * Func. If the expression refers to some RDom, this performs a sum reduction
     * of the negative of the expression over the domain. If the function does
     * not already have a pure definition, this sets it to zero.
     */
    // @{
    Stage operator-=(Expr);
    Stage operator-=(const Tuple &);
    Stage operator-=(const FuncRef &);
    // @}

    /** Define a stage that multiplies this Func by the given expression. If the
     * expression refers to some RDom, this performs a product reduction of the
     * expression over the domain. If the function does not already have a pure
     * definition, this sets it to 1.
     */
    // @{
    Stage operator*=(Expr);
    Stage operator*=(const Tuple &);
    Stage operator*=(const FuncRef &);
    // @}

    /** Define a stage that divides this Func by the given expression.
     * If the expression refers to some RDom, this performs a product
     * reduction of the inverse of the expression over the domain. If the
     * function does not already have a pure definition, this sets it to 1.
     */
    // @{
    Stage operator/=(Expr);
    Stage operator/=(const Tuple &);
    Stage operator/=(const FuncRef &);
    // @}

    /** Override the usual assignment operator, so that
     * f(x, y) = g(x, y) defines f.
     */
    Stage operator=(const FuncRef &);

    /** Use this as a call to the function, and not the left-hand-side
     * of a definition. Only works for single-output Funcs. */
    operator Expr() const;

    /** When a FuncRef refers to a function that provides multiple
     * outputs, you can access each output as an Expr using
     * operator[].
     */
    FuncTupleElementRef operator[](int) const;

    /** How many outputs does the function this refers to produce. */
    size_t size() const;

    /** What function is this calling? */
    Internal::Function function() const {
        return func;
    }
};
16453
16454/** Explicit overloads of min and max for FuncRef. These exist to
16455 * disambiguate calls to min on FuncRefs when a user has pulled both
16456 * Halide::min and std::min into their namespace. */
16457// @{
16458inline Expr min(const FuncRef &a, const FuncRef &b) {
16459 return min(Expr(a), Expr(b));
16460}
16461inline Expr max(const FuncRef &a, const FuncRef &b) {
16462 return max(Expr(a), Expr(b));
16463}
16464// @}
16465
/** A fragment of front-end syntax of the form f(x, y, z)[index], where x, y,
 * z are Vars or Exprs. It could be the left hand side of an update
 * definition, or it could be a call to a function. We don't know
 * until we see how this object gets used.
 */
class FuncTupleElementRef {
    FuncRef func_ref;
    std::vector<Expr> args;  // args to the function
    int idx;                 // Index to function outputs

    /** Helper function that generates a Tuple where element at 'idx' is set
     * to 'e' and the rests are undef. */
    Tuple values_with_undefs(const Expr &e) const;

public:
    FuncTupleElementRef(const FuncRef &ref, const std::vector<Expr> &args, int idx);

    /** Use this as the left-hand-side of an update definition of Tuple
     * component 'idx' of a Func (see \ref RDom). The function must
     * already have an initial definition.
     */
    Stage operator=(const Expr &e);

    /** Define a stage that adds the given expression to Tuple component 'idx'
     * of this Func. The other Tuple components are unchanged. If the expression
     * refers to some RDom, this performs a sum reduction of the expression over
     * the domain. The function must already have an initial definition.
     */
    Stage operator+=(const Expr &e);

    /** Define a stage that adds the negative of the given expression to Tuple
     * component 'idx' of this Func. The other Tuple components are unchanged.
     * If the expression refers to some RDom, this performs a sum reduction of
     * the negative of the expression over the domain. The function must already
     * have an initial definition.
     */
    Stage operator-=(const Expr &e);

    /** Define a stage that multiplies Tuple component 'idx' of this Func by
     * the given expression. The other Tuple components are unchanged. If the
     * expression refers to some RDom, this performs a product reduction of
     * the expression over the domain. The function must already have an
     * initial definition.
     */
    Stage operator*=(const Expr &e);

    /** Define a stage that divides Tuple component 'idx' of this Func by
     * the given expression. The other Tuple components are unchanged.
     * If the expression refers to some RDom, this performs a product
     * reduction of the inverse of the expression over the domain. The function
     * must already have an initial definition.
     */
    Stage operator/=(const Expr &e);

    /** Override the usual assignment operator, so that
     * f(x, y)[index] = g(x, y) defines f.
     */
    Stage operator=(const FuncRef &e);

    /** Use this as a call to Tuple component 'idx' of a Func, and not the
     * left-hand-side of a definition. */
    operator Expr() const;

    /** What function is this calling? */
    Internal::Function function() const {
        return func_ref.function();
    }

    /** Return index to the function outputs. */
    int index() const {
        return idx;
    }
};
16539
16540namespace Internal {
16541class IRMutator;
16542} // namespace Internal
16543
16544/** Helper class for identifying purpose of an Expr passed to memoize.
16545 */
16546class EvictionKey {
16547protected:
16548 Expr key;
16549 friend class Func;
16550
16551public:
16552 explicit EvictionKey(const Expr &expr = Expr())
16553 : key(expr) {
16554 }
16555};
16556
16557/** A halide function. This class represents one stage in a Halide
16558 * pipeline, and is the unit by which we schedule things. By default
16559 * they are aggressively inlined, so you are encouraged to make lots
16560 * of little functions, rather than storing things in Exprs. */
16561class Func {
16562
16563 /** A handle on the internal halide function that this
16564 * represents */
16565 Internal::Function func;
16566
16567 /** When you make a reference to this function with fewer
16568 * arguments than it has dimensions, the argument list is bulked
16569 * up with 'implicit' vars with canonical names. This lets you
16570 * pass around partially applied Halide functions. */
16571 // @{
16572 std::pair<int, int> add_implicit_vars(std::vector<Var> &) const;
16573 std::pair<int, int> add_implicit_vars(std::vector<Expr> &) const;
16574 // @}
16575
16576 /** The imaging pipeline that outputs this Func alone. */
16577 Pipeline pipeline_;
16578
16579 /** Get the imaging pipeline that outputs this Func alone,
16580 * creating it (and freezing the Func) if necessary. */
16581 Pipeline pipeline();
16582
16583 // Helper function for recursive reordering support
16584 Func &reorder_storage(const std::vector<Var> &dims, size_t start);
16585
16586 void invalidate_cache();
16587
16588public:
16589 /** Declare a new undefined function with the given name */
16590 explicit Func(const std::string &name);
16591
16592 /** Declare a new undefined function with an
16593 * automatically-generated unique name */
16594 Func();
16595
16596 /** Declare a new function with an automatically-generated unique
16597 * name, and define it to return the given expression (which may
16598 * not contain free variables). */
16599 explicit Func(const Expr &e);
16600
16601 /** Construct a new Func to wrap an existing, already-define
16602 * Function object. */
16603 explicit Func(Internal::Function f);
16604
16605 /** Construct a new Func to wrap a Buffer. */
16606 template<typename T>
16607 HALIDE_NO_USER_CODE_INLINE explicit Func(Buffer<T> &im)
16608 : Func() {
16609 (*this)(_) = im(_);
16610 }
16611
16612 /** Evaluate this function over some rectangular domain and return
16613 * the resulting buffer or buffers. Performs compilation if the
16614 * Func has not previously been realized and compile_jit has not
16615 * been called. If the final stage of the pipeline is on the GPU,
16616 * data is copied back to the host before being returned. The
16617 * returned Realization should probably be instantly converted to
16618 * a Buffer class of the appropriate type. That is, do this:
16619 *
16620 \code
16621 f(x) = sin(x);
16622 Buffer<float> im = f.realize(...);
16623 \endcode
16624 *
16625 * If your Func has multiple values, because you defined it using
16626 * a Tuple, then casting the result of a realize call to a buffer
16627 * or image will produce a run-time error. Instead you should do the
16628 * following:
16629 *
16630 \code
16631 f(x) = Tuple(x, sin(x));
16632 Realization r = f.realize(...);
16633 Buffer<int> im0 = r[0];
16634 Buffer<float> im1 = r[1];
16635 \endcode
16636 *
16637 * In Halide formal arguments of a computation are specified using
16638 * Param<T> and ImageParam objects in the expressions defining the
16639 * computation. The param_map argument to realize allows
16640 * specifying a set of per-call parameters to be used for a
16641 * specific computation. This method is thread-safe where the
16642 * globals used by Param<T> and ImageParam are not. Any parameters
16643 * that are not in the param_map are taken from the global values,
16644 * so those can continue to be used if they are not changing
16645 * per-thread.
16646 *
16647 * One can explicitly construct a ParamMap and
16648 * use its set method to insert Parameter to scalar or Buffer
16649 * value mappings:
16650 *
16651 \code
    Param<int32_t> p(42);
16653 ImageParam img(Int(32), 1);
16654 f(x) = img(x) + p;
16655
    Buffer<int32_t> arg_img(10, 10);
16657 <fill in arg_img...>
16658 ParamMap params;
16659 params.set(p, 17);
16660 params.set(img, arg_img);
16661
16662 Target t = get_jit_target_from_environment();
16663 Buffer<int32_t> result = f.realize({10, 10}, t, params);
16664 \endcode
16665 *
16666 * Alternatively, an initializer list can be used
16667 * directly in the realize call to pass this information:
16668 *
16669 \code
    Param<int32_t> p(42);
16671 ImageParam img(Int(32), 1);
16672 f(x) = img(x) + p;
16673
    Buffer<int32_t> arg_img(10, 10);
16675 <fill in arg_img...>
16676
16677 Target t = get_jit_target_from_environment();
16678 Buffer<int32_t> result = f.realize({10, 10}, t, { { p, 17 }, { img, arg_img } });
16679 \endcode
16680 *
16681 * If the Func cannot be realized into a buffer of the given size
16682 * due to scheduling constraints on scattering update definitions,
16683 * it will be realized into a larger buffer of the minimum size
16684 * possible, and a cropped view at the requested size will be
16685 * returned. It is thus not safe to assume the returned buffers
16686 * are contiguous in memory. This behavior can be disabled with
16687 * the NoBoundsQuery target flag, in which case an error about
16688 * writing out of bounds on the output buffer will trigger
16689 * instead.
16690 *
16691 */
16692 // @{
16693 Realization realize(std::vector<int32_t> sizes = {}, const Target &target = Target(),
16694 const ParamMap &param_map = ParamMap::empty_map());
16695 HALIDE_ATTRIBUTE_DEPRECATED("Call realize() with a vector<int> instead")
16696 Realization realize(int x_size, int y_size, int z_size, int w_size, const Target &target = Target(),
16697 const ParamMap &param_map = ParamMap::empty_map());
16698 HALIDE_ATTRIBUTE_DEPRECATED("Call realize() with a vector<int> instead")
16699 Realization realize(int x_size, int y_size, int z_size, const Target &target = Target(),
16700 const ParamMap &param_map = ParamMap::empty_map());
16701 HALIDE_ATTRIBUTE_DEPRECATED("Call realize() with a vector<int> instead")
16702 Realization realize(int x_size, int y_size, const Target &target = Target(),
16703 const ParamMap &param_map = ParamMap::empty_map());
16704
16705 // Making this a template function is a trick: `{intliteral}` is a valid scalar initializer
16706 // in C++, but we want it to match the vector call, not the (deprecated) scalar one.
16707 template<typename T, typename = typename std::enable_if<std::is_same<T, int>::value>::type>
16708 HALIDE_ATTRIBUTE_DEPRECATED("Call realize() with a vector<int> instead")
16709 HALIDE_ALWAYS_INLINE Realization realize(T x_size, const Target &target = Target(),
16710 const ParamMap &param_map = ParamMap::empty_map()) {
16711 return realize(std::vector<int32_t>{x_size}, target, param_map);
16712 }
16713 // @}
16714
16715 /** Evaluate this function into an existing allocated buffer or
16716 * buffers. If the buffer is also one of the arguments to the
16717 * function, strange things may happen, as the pipeline isn't
16718 * necessarily safe to run in-place. If you pass multiple buffers,
16719 * they must have matching sizes. This form of realize does *not*
16720 * automatically copy data back from the GPU. */
16721 void realize(Pipeline::RealizationArg outputs, const Target &target = Target(),
16722 const ParamMap &param_map = ParamMap::empty_map());
16723
16724 /** For a given size of output, or a given output buffer,
16725 * determine the bounds required of all unbound ImageParams
16726 * referenced. Communicates the result by allocating new buffers
16727 * of the appropriate size and binding them to the unbound
16728 * ImageParams.
16729 *
     * See the documentation for Func::realize regarding the
16731 * ParamMap. There is one difference in that input Buffer<>
16732 * arguments that are being inferred are specified as a pointer to
16733 * the Buffer<> in the ParamMap. E.g.
16734 *
16735 \code
    Param<int32_t> p(42);
16737 ImageParam img(Int(32), 1);
16738 f(x) = img(x) + p;
16739
16740 Target t = get_jit_target_from_environment();
16741 Buffer<> in;
16742 f.infer_input_bounds({10, 10}, t, { { img, &in } });
16743 \endcode
16744 * On return, in will be an allocated buffer of the correct size
     * to evaluate f over a 10x10 region.
16746 */
16747 // @{
16748 void infer_input_bounds(const std::vector<int32_t> &sizes,
16749 const Target &target = get_jit_target_from_environment(),
16750 const ParamMap &param_map = ParamMap::empty_map());
16751 void infer_input_bounds(Pipeline::RealizationArg outputs,
16752 const Target &target = get_jit_target_from_environment(),
16753 const ParamMap &param_map = ParamMap::empty_map());
16754 // @}
16755
16756 /** Statically compile this function to llvm bitcode, with the
16757 * given filename (which should probably end in .bc), type
16758 * signature, and C function name (which defaults to the same name
     * as this halide function) */
16760 //@{
16761 void compile_to_bitcode(const std::string &filename, const std::vector<Argument> &, const std::string &fn_name,
16762 const Target &target = get_target_from_environment());
16763 void compile_to_bitcode(const std::string &filename, const std::vector<Argument> &,
16764 const Target &target = get_target_from_environment());
16765 // @}
16766
16767 /** Statically compile this function to llvm assembly, with the
16768 * given filename (which should probably end in .ll), type
16769 * signature, and C function name (which defaults to the same name
     * as this halide function) */
16771 //@{
16772 void compile_to_llvm_assembly(const std::string &filename, const std::vector<Argument> &, const std::string &fn_name,
16773 const Target &target = get_target_from_environment());
16774 void compile_to_llvm_assembly(const std::string &filename, const std::vector<Argument> &,
16775 const Target &target = get_target_from_environment());
16776 // @}
16777
16778 /** Statically compile this function to an object file, with the
16779 * given filename (which should probably end in .o or .obj), type
16780 * signature, and C function name (which defaults to the same name
     * as this halide function). You probably don't want to use this
16782 * directly; call compile_to_static_library or compile_to_file instead. */
16783 //@{
16784 void compile_to_object(const std::string &filename, const std::vector<Argument> &, const std::string &fn_name,
16785 const Target &target = get_target_from_environment());
16786 void compile_to_object(const std::string &filename, const std::vector<Argument> &,
16787 const Target &target = get_target_from_environment());
16788 // @}
16789
16790 /** Emit a header file with the given filename for this
16791 * function. The header will define a function with the type
16792 * signature given by the second argument, and a name given by the
16793 * third. The name defaults to the same name as this halide
16794 * function. You don't actually have to have defined this function
16795 * yet to call this. You probably don't want to use this directly;
16796 * call compile_to_static_library or compile_to_file instead. */
16797 void compile_to_header(const std::string &filename, const std::vector<Argument> &, const std::string &fn_name = "",
16798 const Target &target = get_target_from_environment());
16799
16800 /** Statically compile this function to text assembly equivalent
16801 * to the object file generated by compile_to_object. This is
16802 * useful for checking what Halide is producing without having to
16803 * disassemble anything, or if you need to feed the assembly into
16804 * some custom toolchain to produce an object file (e.g. iOS) */
16805 //@{
16806 void compile_to_assembly(const std::string &filename, const std::vector<Argument> &, const std::string &fn_name,
16807 const Target &target = get_target_from_environment());
16808 void compile_to_assembly(const std::string &filename, const std::vector<Argument> &,
16809 const Target &target = get_target_from_environment());
16810 // @}
16811
16812 /** Statically compile this function to C source code. This is
16813 * useful for providing fallback code paths that will compile on
16814 * many platforms. Vectorization will fail, and parallelization
16815 * will produce serial code. */
16816 void compile_to_c(const std::string &filename,
16817 const std::vector<Argument> &,
16818 const std::string &fn_name = "",
16819 const Target &target = get_target_from_environment());
16820
16821 /** Write out an internal representation of lowered code. Useful
16822 * for analyzing and debugging scheduling. Can emit html or plain
16823 * text. */
16824 void compile_to_lowered_stmt(const std::string &filename,
16825 const std::vector<Argument> &args,
16826 StmtOutputFormat fmt = Text,
16827 const Target &target = get_target_from_environment());
16828
16829 /** Write out the loop nests specified by the schedule for this
16830 * Function. Helpful for understanding what a schedule is
16831 * doing. */
16832 void print_loop_nest();
16833
16834 /** Compile to object file and header pair, with the given
16835 * arguments. The name defaults to the same name as this halide
16836 * function.
16837 */
16838 void compile_to_file(const std::string &filename_prefix, const std::vector<Argument> &args,
16839 const std::string &fn_name = "",
16840 const Target &target = get_target_from_environment());
16841
16842 /** Compile to static-library file and header pair, with the given
16843 * arguments. The name defaults to the same name as this halide
16844 * function.
16845 */
16846 void compile_to_static_library(const std::string &filename_prefix, const std::vector<Argument> &args,
16847 const std::string &fn_name = "",
16848 const Target &target = get_target_from_environment());
16849
16850 /** Compile to static-library file and header pair once for each target;
16851 * each resulting function will be considered (in order) via halide_can_use_target_features()
16852 * at runtime, with the first appropriate match being selected for subsequent use.
16853 * This is typically useful for specializations that may vary unpredictably by machine
16854 * (e.g., SSE4.1/AVX/AVX2 on x86 desktop machines).
16855 * All targets must have identical arch-os-bits.
16856 */
16857 void compile_to_multitarget_static_library(const std::string &filename_prefix,
16858 const std::vector<Argument> &args,
16859 const std::vector<Target> &targets);
16860
16861 /** Like compile_to_multitarget_static_library(), except that the object files
16862 * are all output as object files (rather than bundled into a static library).
16863 *
     * `suffixes` is an optional list of strings to use as the suffix for each object
16865 * file. If nonempty, it must be the same length as `targets`. (If empty, Target::to_string()
16866 * will be used for each suffix.)
16867 *
16868 * Note that if `targets.size()` > 1, the wrapper code (to select the subtarget)
16869 * will be generated with the filename `${filename_prefix}_wrapper.o`
16870 *
16871 * Note that if `targets.size()` > 1 and `no_runtime` is not specified, the runtime
16872 * will be generated with the filename `${filename_prefix}_runtime.o`
16873 */
16874 void compile_to_multitarget_object_files(const std::string &filename_prefix,
16875 const std::vector<Argument> &args,
16876 const std::vector<Target> &targets,
16877 const std::vector<std::string> &suffixes);
16878
16879 /** Store an internal representation of lowered code as a self
16880 * contained Module suitable for further compilation. */
16881 Module compile_to_module(const std::vector<Argument> &args, const std::string &fn_name = "",
16882 const Target &target = get_target_from_environment());
16883
16884 /** Compile and generate multiple target files with single call.
16885 * Deduces target files based on filenames specified in
16886 * output_files map.
16887 */
16888 void compile_to(const std::map<Output, std::string> &output_files,
16889 const std::vector<Argument> &args,
16890 const std::string &fn_name,
16891 const Target &target = get_target_from_environment());
16892
16893 /** Eagerly jit compile the function to machine code. This
16894 * normally happens on the first call to realize. If you're
16895 * running your halide pipeline inside time-sensitive code and
16896 * wish to avoid including the time taken to compile a pipeline,
16897 * then you can call this ahead of time. Default is to use the Target
16898 * returned from Halide::get_jit_target_from_environment()
16899 */
16900 void compile_jit(const Target &target = get_jit_target_from_environment());
16901
16902 /** Set the error handler function that be called in the case of
16903 * runtime errors during halide pipelines. If you are compiling
16904 * statically, you can also just define your own function with
16905 * signature
16906 \code
16907 extern "C" void halide_error(void *user_context, const char *);
16908 \endcode
16909 * This will clobber Halide's version.
16910 */
16911 void set_error_handler(void (*handler)(void *, const char *));
16912
16913 /** Set a custom malloc and free for halide to use. Malloc should
16914 * return 32-byte aligned chunks of memory, and it should be safe
16915 * for Halide to read slightly out of bounds (up to 8 bytes before
16916 * the start or beyond the end). If compiling statically, routines
16917 * with appropriate signatures can be provided directly
16918 \code
16919 extern "C" void *halide_malloc(void *, size_t)
16920 extern "C" void halide_free(void *, void *)
16921 \endcode
16922 * These will clobber Halide's versions. See HalideRuntime.h
16923 * for declarations.
16924 */
16925 void set_custom_allocator(void *(*malloc)(void *, size_t),
16926 void (*free)(void *, void *));
16927
16928 /** Set a custom task handler to be called by the parallel for
16929 * loop. It is useful to set this if you want to do some
16930 * additional bookkeeping at the granularity of parallel
16931 * tasks. The default implementation does this:
16932 \code
16933 extern "C" int halide_do_task(void *user_context,
16934 int (*f)(void *, int, uint8_t *),
16935 int idx, uint8_t *state) {
16936 return f(user_context, idx, state);
16937 }
16938 \endcode
16939 * If you are statically compiling, you can also just define your
16940 * own version of the above function, and it will clobber Halide's
16941 * version.
16942 *
16943 * If you're trying to use a custom parallel runtime, you probably
16944 * don't want to call this. See instead \ref Func::set_custom_do_par_for .
16945 */
16946 void set_custom_do_task(
16947 int (*custom_do_task)(void *, int (*)(void *, int, uint8_t *),
16948 int, uint8_t *));
16949
16950 /** Set a custom parallel for loop launcher. Useful if your app
16951 * already manages a thread pool. The default implementation is
16952 * equivalent to this:
16953 \code
16954 extern "C" int halide_do_par_for(void *user_context,
16955 int (*f)(void *, int, uint8_t *),
16956 int min, int extent, uint8_t *state) {
16957 int exit_status = 0;
16958 parallel for (int idx = min; idx < min+extent; idx++) {
16959 int job_status = halide_do_task(user_context, f, idx, state);
16960 if (job_status) exit_status = job_status;
16961 }
16962 return exit_status;
16963 }
16964 \endcode
16965 *
16966 * However, notwithstanding the above example code, if one task
16967 * fails, we may skip over other tasks, and if two tasks return
16968 * different error codes, we may select one arbitrarily to return.
16969 *
16970 * If you are statically compiling, you can also just define your
16971 * own version of the above function, and it will clobber Halide's
16972 * version.
16973 */
16974 void set_custom_do_par_for(
16975 int (*custom_do_par_for)(void *, int (*)(void *, int, uint8_t *), int,
16976 int, uint8_t *));
16977
16978 /** Set custom routines to call when tracing is enabled. Call this
16979 * on the output Func of your pipeline. This then sets custom
16980 * routines for the entire pipeline, not just calls to this
16981 * Func.
16982 *
16983 * If you are statically compiling, you can also just define your
16984 * own versions of the tracing functions (see HalideRuntime.h),
16985 * and they will clobber Halide's versions. */
16986 void set_custom_trace(int (*trace_fn)(void *, const halide_trace_event_t *));
16987
16988 /** Set the function called to print messages from the runtime.
16989 * If you are compiling statically, you can also just define your
16990 * own function with signature
16991 \code
16992 extern "C" void halide_print(void *user_context, const char *);
16993 \endcode
16994 * This will clobber Halide's version.
16995 */
16996 void set_custom_print(void (*handler)(void *, const char *));
16997
16998 /** Get a struct containing the currently set custom functions
16999 * used by JIT. */
17000 const Internal::JITHandlers &jit_handlers();
17001
17002 /** Add a custom pass to be used during lowering. It is run after
17003 * all other lowering passes. Can be used to verify properties of
17004 * the lowered Stmt, instrument it with extra code, or otherwise
17005 * modify it. The Func takes ownership of the pass, and will call
17006 * delete on it when the Func goes out of scope. So don't pass a
17007 * stack object, or share pass instances between multiple
17008 * Funcs. */
17009 template<typename T>
17010 void add_custom_lowering_pass(T *pass) {
17011 // Template instantiate a custom deleter for this type, then
17012 // wrap in a lambda. The custom deleter lives in user code, so
17013 // that deletion is on the same heap as construction (I hate Windows).
17014 add_custom_lowering_pass(pass, [pass]() { delete_lowering_pass<T>(pass); });
17015 }
17016
17017 /** Add a custom pass to be used during lowering, with the
17018 * function that will be called to delete it also passed in. Set
17019 * it to nullptr if you wish to retain ownership of the object. */
17020 void add_custom_lowering_pass(Internal::IRMutator *pass, std::function<void()> deleter);
17021
17022 /** Remove all previously-set custom lowering passes */
17023 void clear_custom_lowering_passes();
17024
17025 /** Get the custom lowering passes. */
17026 const std::vector<CustomLoweringPass> &custom_lowering_passes();
17027
17028 /** When this function is compiled, include code that dumps its
17029 * values to a file after it is realized, for the purpose of
17030 * debugging.
17031 *
17032 * If filename ends in ".tif" or ".tiff" (case insensitive) the file
     * is in TIFF format and can be read by standard tools. Otherwise, the
17034 * file format is as follows:
17035 *
17036 * All data is in the byte-order of the target platform. First, a
17037 * 20 byte-header containing four 32-bit ints, giving the extents
17038 * of the first four dimensions. Dimensions beyond four are
17039 * folded into the fourth. Then, a fifth 32-bit int giving the
17040 * data type of the function. The typecodes are given by: float =
17041 * 0, double = 1, uint8_t = 2, int8_t = 3, uint16_t = 4, int16_t =
17042 * 5, uint32_t = 6, int32_t = 7, uint64_t = 8, int64_t = 9. The
17043 * data follows the header, as a densely packed array of the given
17044 * size and the given type. If given the extension .tmp, this file
17045 * format can be natively read by the program ImageStack. */
17046 void debug_to_file(const std::string &filename);
17047
17048 /** The name of this function, either given during construction,
17049 * or automatically generated. */
17050 const std::string &name() const;
17051
17052 /** Get the pure arguments. */
17053 std::vector<Var> args() const;
17054
17055 /** The right-hand-side value of the pure definition of this
17056 * function. Causes an error if there's no pure definition, or if
17057 * the function is defined to return multiple values. */
17058 Expr value() const;
17059
17060 /** The values returned by this function. An error if the function
     * has not been defined. Returns a Tuple with one element for
17062 * functions defined to return a single value. */
17063 Tuple values() const;
17064
17065 /** Does this function have at least a pure definition. */
17066 bool defined() const;
17067
17068 /** Get the left-hand-side of the update definition. An empty
17069 * vector if there's no update definition. If there are
17070 * multiple update definitions for this function, use the
17071 * argument to select which one you want. */
17072 const std::vector<Expr> &update_args(int idx = 0) const;
17073
17074 /** Get the right-hand-side of an update definition. An error if
17075 * there's no update definition. If there are multiple
17076 * update definitions for this function, use the argument to
17077 * select which one you want. */
17078 Expr update_value(int idx = 0) const;
17079
17080 /** Get the right-hand-side of an update definition for
17081 * functions that returns multiple values. An error if there's no
17082 * update definition. Returns a Tuple with one element for
17083 * functions that return a single value. */
17084 Tuple update_values(int idx = 0) const;
17085
17086 /** Get the RVars of the reduction domain for an update definition, if there is
17087 * one. */
17088 std::vector<RVar> rvars(int idx = 0) const;
17089
17090 /** Does this function have at least one update definition? */
17091 bool has_update_definition() const;
17092
17093 /** How many update definitions does this function have? */
17094 int num_update_definitions() const;
17095
17096 /** Is this function an external stage? That is, was it defined
17097 * using define_extern? */
17098 bool is_extern() const;
17099
17100 /** Add an extern definition for this Func. This lets you define a
17101 * Func that represents an external pipeline stage. You can, for
17102 * example, use it to wrap a call to an extern library such as
17103 * fftw. */
17104 // @{
17105 void define_extern(const std::string &function_name,
17106 const std::vector<ExternFuncArgument> &params, Type t,
17107 int dimensionality,
17108 NameMangling mangling = NameMangling::Default,
17109 DeviceAPI device_api = DeviceAPI::Host) {
17110 define_extern(function_name, params, t,
17111 Internal::make_argument_list(dimensionality), mangling,
17112 device_api);
17113 }
17114
17115 void define_extern(const std::string &function_name,
17116 const std::vector<ExternFuncArgument> &params,
17117 const std::vector<Type> &types, int dimensionality,
17118 NameMangling mangling) {
17119 define_extern(function_name, params, types,
17120 Internal::make_argument_list(dimensionality), mangling);
17121 }
17122
17123 void define_extern(const std::string &function_name,
17124 const std::vector<ExternFuncArgument> &params,
17125 const std::vector<Type> &types, int dimensionality,
17126 NameMangling mangling = NameMangling::Default,
17127 DeviceAPI device_api = DeviceAPI::Host) {
17128 define_extern(function_name, params, types,
17129 Internal::make_argument_list(dimensionality), mangling,
17130 device_api);
17131 }
17132
17133 void define_extern(const std::string &function_name,
17134 const std::vector<ExternFuncArgument> &params, Type t,
17135 const std::vector<Var> &arguments,
17136 NameMangling mangling = NameMangling::Default,
17137 DeviceAPI device_api = DeviceAPI::Host) {
17138 define_extern(function_name, params, std::vector<Type>{t}, arguments,
17139 mangling, device_api);
17140 }
17141
17142 void define_extern(const std::string &function_name,
17143 const std::vector<ExternFuncArgument> &params,
17144 const std::vector<Type> &types,
17145 const std::vector<Var> &arguments,
17146 NameMangling mangling = NameMangling::Default,
17147 DeviceAPI device_api = DeviceAPI::Host);
17148 // @}
17149
17150 /** Get the types of the outputs of this Func. */
17151 const std::vector<Type> &output_types() const;
17152
17153 /** Get the number of outputs of this Func. Corresponds to the
17154 * size of the Tuple this Func was defined to return. */
17155 int outputs() const;
17156
17157 /** Get the name of the extern function called for an extern
17158 * definition. */
17159 const std::string &extern_function_name() const;
17160
17161 /** The dimensionality (number of arguments) of this
17162 * function. Zero if the function is not yet defined. */
17163 int dimensions() const;
17164
17165 /** Construct either the left-hand-side of a definition, or a call
     * to a function that happens to only contain vars as
17167 * arguments. If the function has already been defined, and fewer
17168 * arguments are given than the function has dimensions, then
17169 * enough implicit vars are added to the end of the argument list
17170 * to make up the difference (see \ref Var::implicit) */
17171 // @{
17172 FuncRef operator()(std::vector<Var>) const;
17173
17174 template<typename... Args>
17175 HALIDE_NO_USER_CODE_INLINE typename std::enable_if<Internal::all_are_convertible<Var, Args...>::value, FuncRef>::type
17176 operator()(Args &&...args) const {
17177 std::vector<Var> collected_args{std::forward<Args>(args)...};
17178 return this->operator()(collected_args);
17179 }
17180 // @}
17181
17182 /** Either calls to the function, or the left-hand-side of
17183 * an update definition (see \ref RDom). If the function has
17184 * already been defined, and fewer arguments are given than the
17185 * function has dimensions, then enough implicit vars are added to
17186 * the end of the argument list to make up the difference. (see
17187 * \ref Var::implicit)*/
17188 // @{
17189 FuncRef operator()(std::vector<Expr>) const;
17190
17191 template<typename... Args>
17192 HALIDE_NO_USER_CODE_INLINE typename std::enable_if<Internal::all_are_convertible<Expr, Args...>::value, FuncRef>::type
17193 operator()(const Expr &x, Args &&...args) const {
17194 std::vector<Expr> collected_args{x, std::forward<Args>(args)...};
17195 return (*this)(collected_args);
17196 }
17197 // @}
17198
17199 /** Creates and returns a new identity Func that wraps this Func. During
17200 * compilation, Halide replaces all calls to this Func done by 'f'
17201 * with calls to the wrapper. If this Func is already wrapped for
17202 * use in 'f', will return the existing wrapper.
17203 *
17204 * For example, g.in(f) would rewrite a pipeline like this:
17205 \code
17206 g(x, y) = ...
17207 f(x, y) = ... g(x, y) ...
17208 \endcode
17209 * into a pipeline like this:
17210 \code
17211 g(x, y) = ...
17212 g_wrap(x, y) = g(x, y)
17213 f(x, y) = ... g_wrap(x, y)
17214 \endcode
17215 *
17216 * This has a variety of uses. You can use it to schedule this
17217 * Func differently in the different places it is used:
17218 \code
17219 g(x, y) = ...
17220 f1(x, y) = ... g(x, y) ...
17221 f2(x, y) = ... g(x, y) ...
17222 g.in(f1).compute_at(f1, y).vectorize(x, 8);
17223 g.in(f2).compute_at(f2, x).unroll(x);
17224 \endcode
17225 *
17226 * You can also use it to stage loads from this Func via some
17227 * intermediate buffer (perhaps on the stack as in
17228 * test/performance/block_transpose.cpp, or in shared GPU memory
     * as in test/performance/wrap.cpp). In this case we compute the
17230 * wrapper at tiles of the consuming Funcs like so:
17231 \code
17232 g.compute_root()...
17233 g.in(f).compute_at(f, tiles)...
17234 \endcode
17235 *
17236 * Func::in() can also be used to compute pieces of a Func into a
17237 * smaller scratch buffer (perhaps on the GPU) and then copy them
17238 * into a larger output buffer one tile at a time. See
17239 * apps/interpolate/interpolate.cpp for an example of this. In
17240 * this case we compute the Func at tiles of its own wrapper:
17241 \code
17242 f.in(g).compute_root().gpu_tile(...)...
17243 f.compute_at(f.in(g), tiles)...
17244 \endcode
17245 *
     * A similar use of Func::in() is to wrap Funcs with multiple update
     * stages in a pure wrapper. The following code:
17248 \code
17249 f(x, y) = x + y;
17250 f(x, y) += 5;
17251 g(x, y) = f(x, y);
17252 f.compute_root();
17253 \endcode
17254 *
17255 * Is equivalent to:
17256 \code
17257 for y:
17258 for x:
17259 f(x, y) = x + y;
17260 for y:
17261 for x:
17262 f(x, y) += 5
17263 for y:
17264 for x:
17265 g(x, y) = f(x, y)
17266 \endcode
17267 * using Func::in(), we can write:
17268 \code
17269 f(x, y) = x + y;
17270 f(x, y) += 5;
17271 g(x, y) = f(x, y);
17272 f.in(g).compute_root();
17273 \endcode
17274 * which instead produces:
17275 \code
17276 for y:
17277 for x:
17278 f(x, y) = x + y;
17279 f(x, y) += 5
17280 f_wrap(x, y) = f(x, y)
17281 for y:
17282 for x:
17283 g(x, y) = f_wrap(x, y)
17284 \endcode
17285 */
17286 Func in(const Func &f);
17287
17288 /** Create and return an identity wrapper shared by all the Funcs in
17289 * 'fs'. If any of the Funcs in 'fs' already have a custom wrapper,
17290 * this will throw an error. */
17291 Func in(const std::vector<Func> &fs);
17292
17293 /** Create and return a global identity wrapper, which wraps all calls to
17294 * this Func by any other Func. If a global wrapper already exists,
17295 * returns it. The global identity wrapper is only used by callers for
17296 * which no custom wrapper has been specified.
17297 */
17298 Func in();
17299
17300 /** Similar to \ref Func::in; however, instead of replacing the call to
17301 * this Func with an identity Func that refers to it, this replaces the
17302 * call with a clone of this Func.
17303 *
17304 * For example, f.clone_in(g) would rewrite a pipeline like this:
17305 \code
17306 f(x, y) = x + y;
17307 g(x, y) = f(x, y) + 2;
17308 h(x, y) = f(x, y) - 3;
17309 \endcode
17310 * into a pipeline like this:
17311 \code
17312 f(x, y) = x + y;
17313 f_clone(x, y) = x + y;
17314 g(x, y) = f_clone(x, y) + 2;
17315 h(x, y) = f(x, y) - 3;
17316 \endcode
17317 *
17318 */
17319 //@{
17320 Func clone_in(const Func &f);
17321 Func clone_in(const std::vector<Func> &fs);
17322 //@}
17323
17324 /** Declare that this function should be implemented by a call to
17325 * halide_buffer_copy with the given target device API. Asserts
17326 * that the Func has a pure definition which is a simple call to a
17327 * single input, and no update definitions. The wrapper Funcs
17328 * returned by in() are suitable candidates. Consumes all pure
17329 * variables, and rewrites the Func to have an extern definition
17330 * that calls halide_buffer_copy. */
17331 Func copy_to_device(DeviceAPI d = DeviceAPI::Default_GPU);
17332
17333 /** Declare that this function should be implemented by a call to
17334 * halide_buffer_copy with a NULL target device API. Equivalent to
17335 * copy_to_device(DeviceAPI::Host). Asserts that the Func has a
17336 * pure definition which is a simple call to a single input, and
17337 * no update definitions. The wrapper Funcs returned by in() are
17338 * suitable candidates. Consumes all pure variables, and rewrites
17339 * the Func to have an extern definition that calls
17340 * halide_buffer_copy.
17341 *
17342 * Note that if the source Func is already valid in host memory,
17343 * this compiles to code that does the minimum number of calls to
17344 * memcpy.
17345 */
17346 Func copy_to_host();
17347
17348 /** Split a dimension into inner and outer subdimensions with the
17349 * given names, where the inner dimension iterates from 0 to
17350 * factor-1. The inner and outer subdimensions can then be dealt
17351 * with using the other scheduling calls. It's ok to reuse the old
17352 * variable name as either the inner or outer variable. The final
17353 * argument specifies how the tail should be handled if the split
17354 * factor does not provably divide the extent. */
17355 Func &split(const VarOrRVar &old, const VarOrRVar &outer, const VarOrRVar &inner, const Expr &factor, TailStrategy tail = TailStrategy::Auto);
17356
    /** Join two dimensions into a single fused dimension. The fused
     * dimension covers the product of the extents of the inner and
     * outer dimensions given. */
    Func &fuse(const VarOrRVar &inner, const VarOrRVar &outer, const VarOrRVar &fused);
17361
17362 /** Mark a dimension to be traversed serially. This is the default. */
17363 Func &serial(const VarOrRVar &var);
17364
17365 /** Mark a dimension to be traversed in parallel */
17366 Func &parallel(const VarOrRVar &var);
17367
    /** Split a dimension by the given task_size, and then parallelize the
     * outer dimension. This creates parallel tasks that have size
     * task_size. After this call, var refers to the outer dimension of
     * the split. The inner dimension has a new anonymous name. If you
     * wish to mutate it, or schedule with respect to it, do the split
     * manually. */
    Func &parallel(const VarOrRVar &var, const Expr &task_size, TailStrategy tail = TailStrategy::Auto);
17375
17376 /** Mark a dimension to be computed all-at-once as a single
17377 * vector. The dimension should have constant extent -
17378 * e.g. because it is the inner dimension following a split by a
17379 * constant factor. For most uses of vectorize you want the two
17380 * argument form. The variable to be vectorized should be the
17381 * innermost one. */
17382 Func &vectorize(const VarOrRVar &var);
17383
17384 /** Mark a dimension to be completely unrolled. The dimension
17385 * should have constant extent - e.g. because it is the inner
17386 * dimension following a split by a constant factor. For most uses
17387 * of unroll you want the two-argument form. */
17388 Func &unroll(const VarOrRVar &var);
17389
17390 /** Split a dimension by the given factor, then vectorize the
17391 * inner dimension. This is how you vectorize a loop of unknown
17392 * size. The variable to be vectorized should be the innermost
17393 * one. After this call, var refers to the outer dimension of the
17394 * split. 'factor' must be an integer. */
17395 Func &vectorize(const VarOrRVar &var, const Expr &factor, TailStrategy tail = TailStrategy::Auto);
17396
17397 /** Split a dimension by the given factor, then unroll the inner
17398 * dimension. This is how you unroll a loop of unknown size by
17399 * some constant factor. After this call, var refers to the outer
17400 * dimension of the split. 'factor' must be an integer. */
17401 Func &unroll(const VarOrRVar &var, const Expr &factor, TailStrategy tail = TailStrategy::Auto);
17402
17403 /** Statically declare that the range over which a function should
17404 * be evaluated is given by the second and third arguments. This
17405 * can let Halide perform some optimizations. E.g. if you know
17406 * there are going to be 4 color channels, you can completely
17407 * vectorize the color channel dimension without the overhead of
17408 * splitting it up. If bounds inference decides that it requires
17409 * more of this function than the bounds you have stated, a
17410 * runtime error will occur when you try to run your pipeline. */
17411 Func &bound(const Var &var, Expr min, Expr extent);
17412
17413 /** Statically declare the range over which the function will be
17414 * evaluated in the general case. This provides a basis for the auto
17415 * scheduler to make trade-offs and scheduling decisions. The auto
17416 * generated schedules might break when the sizes of the dimensions are
17417 * very different from the estimates specified. These estimates are used
17418 * only by the auto scheduler if the function is a pipeline output. */
17419 Func &set_estimate(const Var &var, const Expr &min, const Expr &extent);
17420
17421 /** Set (min, extent) estimates for all dimensions in the Func
17422 * at once; this is equivalent to calling `set_estimate(args()[n], min, extent)`
17423 * repeatedly, but slightly terser. The size of the estimates vector
17424 * must match the dimensionality of the Func. */
17425 Func &set_estimates(const Region &estimates);
17426
    /** Expand the region computed so that the min coordinate is
     * congruent to 'remainder' modulo 'modulus', and the extent is a
     * multiple of 'modulus'. For example, f.align_bounds(x, 2) forces
     * the min and extent realized to be even, and calling
     * f.align_bounds(x, 2, 1) forces the min to be odd and the extent
     * to be even. The region computed always contains the region that
     * would have been computed without this directive, so no
     * assertions are injected.
     */
    Func &align_bounds(const Var &var, Expr modulus, Expr remainder = 0);
17437
17438 /** Expand the region computed so that the extent is a
17439 * multiple of 'modulus'. For example, f.align_extent(x, 2) forces
17440 * the extent realized to be even. The region computed always contains the
17441 * region that would have been computed without this directive, so no
17442 * assertions are injected. (This is essentially equivalent to align_bounds(),
17443 * but always leaving the min untouched.)
17444 */
17445 Func &align_extent(const Var &var, Expr modulus);
17446
17447 /** Bound the extent of a Func's realization, but not its
17448 * min. This means the dimension can be unrolled or vectorized
17449 * even when its min is not fixed (for example because it is
17450 * compute_at tiles of another Func). This can also be useful for
17451 * forcing a function's allocation to be a fixed size, which often
17452 * means it can go on the stack. */
17453 Func &bound_extent(const Var &var, Expr extent);
17454
17455 /** Split two dimensions at once by the given factors, and then
17456 * reorder the resulting dimensions to be xi, yi, xo, yo from
17457 * innermost outwards. This gives a tiled traversal. */
17458 Func &tile(const VarOrRVar &x, const VarOrRVar &y,
17459 const VarOrRVar &xo, const VarOrRVar &yo,
17460 const VarOrRVar &xi, const VarOrRVar &yi,
17461 const Expr &xfactor, const Expr &yfactor,
17462 TailStrategy tail = TailStrategy::Auto);
17463
17464 /** A shorter form of tile, which reuses the old variable names as
17465 * the new outer dimensions */
17466 Func &tile(const VarOrRVar &x, const VarOrRVar &y,
17467 const VarOrRVar &xi, const VarOrRVar &yi,
17468 const Expr &xfactor, const Expr &yfactor,
17469 TailStrategy tail = TailStrategy::Auto);
17470
17471 /** A more general form of tile, which defines tiles of any dimensionality. */
17472 Func &tile(const std::vector<VarOrRVar> &previous,
17473 const std::vector<VarOrRVar> &outers,
17474 const std::vector<VarOrRVar> &inners,
17475 const std::vector<Expr> &factors,
17476 const std::vector<TailStrategy> &tails);
17477
17478 /** The generalized tile, with a single tail strategy to apply to all vars. */
17479 Func &tile(const std::vector<VarOrRVar> &previous,
17480 const std::vector<VarOrRVar> &outers,
17481 const std::vector<VarOrRVar> &inners,
17482 const std::vector<Expr> &factors,
17483 TailStrategy tail = TailStrategy::Auto);
17484
17485 /** Generalized tiling, reusing the previous names as the outer names. */
17486 Func &tile(const std::vector<VarOrRVar> &previous,
17487 const std::vector<VarOrRVar> &inners,
17488 const std::vector<Expr> &factors,
17489 TailStrategy tail = TailStrategy::Auto);
17490
17491 /** Reorder variables to have the given nesting order, from
17492 * innermost out */
17493 Func &reorder(const std::vector<VarOrRVar> &vars);
17494
17495 template<typename... Args>
17496 HALIDE_NO_USER_CODE_INLINE typename std::enable_if<Internal::all_are_convertible<VarOrRVar, Args...>::value, Func &>::type
17497 reorder(const VarOrRVar &x, const VarOrRVar &y, Args &&...args) {
17498 std::vector<VarOrRVar> collected_args{x, y, std::forward<Args>(args)...};
17499 return reorder(collected_args);
17500 }
17501
    /** Rename a dimension. Equivalent to split with an inner size of one. */
    Func &rename(const VarOrRVar &old_name, const VarOrRVar &new_name);
17504
17505 /** Specify that race conditions are permitted for this Func,
17506 * which enables parallelizing over RVars even when Halide cannot
17507 * prove that it is safe to do so. Use this with great caution,
17508 * and only if you can prove to yourself that this is safe, as it
17509 * may result in a non-deterministic routine that returns
17510 * different values at different times or on different machines. */
17511 Func &allow_race_conditions();
17512
17513 /** Issue atomic updates for this Func. This allows parallelization
17514 * on associative RVars. The function throws a compile error when
17515 * Halide fails to prove associativity. Use override_associativity_test
17516 * to disable the associativity test if you believe the function is
17517 * associative or the order of reduction variable execution does not
17518 * matter.
17519 * Halide compiles this into hardware atomic operations whenever possible,
17520 * and falls back to a mutex lock per storage element if it is impossible
17521 * to atomically update.
17522 * There are three possible outcomes of the compiled code:
17523 * atomic add, compare-and-swap loop, and mutex lock.
17524 * For example:
17525 *
17526 * hist(x) = 0;
17527 * hist(im(r)) += 1;
17528 * hist.compute_root();
17529 * hist.update().atomic().parallel();
17530 *
17531 * will be compiled to atomic add operations.
17532 *
17533 * hist(x) = 0;
17534 * hist(im(r)) = min(hist(im(r)) + 1, 100);
17535 * hist.compute_root();
17536 * hist.update().atomic().parallel();
17537 *
17538 * will be compiled to compare-and-swap loops.
17539 *
17540 * arg_max() = {0, im(0)};
17541 * Expr old_index = arg_max()[0];
17542 * Expr old_max = arg_max()[1];
17543 * Expr new_index = select(old_max < im(r), r, old_index);
17544 * Expr new_max = max(im(r), old_max);
17545 * arg_max() = {new_index, new_max};
17546 * arg_max.compute_root();
17547 * arg_max.update().atomic().parallel();
17548 *
17549 * will be compiled to updates guarded by a mutex lock,
17550 * since it is impossible to atomically update two different locations.
17551 *
17552 * Currently the atomic operation is supported by x86, CUDA, and OpenCL backends.
17553 * Compiling to other backends results in a compile error.
17554 * If an operation is compiled into a mutex lock, and is vectorized or is
17555 * compiled to CUDA or OpenCL, it also results in a compile error,
17556 * since per-element mutex lock on vectorized operation leads to a
17557 * deadlock.
17558 * Vectorization of predicated RVars (through rdom.where()) on CPU
17559 * is also unsupported yet (see https://github.com/halide/Halide/issues/4298).
17560 * 8-bit and 16-bit atomics on GPU are also not supported. */
17561 Func &atomic(bool override_associativity_test = false);
17562
17563 /** Specialize a Func. This creates a special-case version of the
17564 * Func where the given condition is true. The most effective
17565 * conditions are those of the form param == value, and boolean
17566 * Params. Consider a simple example:
17567 \code
17568 f(x) = x + select(cond, 0, 1);
17569 f.compute_root();
17570 \endcode
17571 * This is equivalent to:
17572 \code
17573 for (int x = 0; x < width; x++) {
17574 f[x] = x + (cond ? 0 : 1);
17575 }
17576 \endcode
17577 * Adding the scheduling directive:
17578 \code
17579 f.specialize(cond)
17580 \endcode
17581 * makes it equivalent to:
17582 \code
17583 if (cond) {
17584 for (int x = 0; x < width; x++) {
17585 f[x] = x;
17586 }
17587 } else {
17588 for (int x = 0; x < width; x++) {
17589 f[x] = x + 1;
17590 }
17591 }
17592 \endcode
17593 * Note that the inner loops have been simplified. In the first
17594 * path Halide knows that cond is true, and in the second path
17595 * Halide knows that it is false.
17596 *
17597 * The specialized version gets its own schedule, which inherits
17598 * every directive made about the parent Func's schedule so far
17599 * except for its specializations. This method returns a handle to
17600 * the new schedule. If you wish to retrieve the specialized
17601 * sub-schedule again later, you can call this method with the
17602 * same condition. Consider the following example of scheduling
17603 * the specialized version:
17604 *
17605 \code
17606 f(x) = x;
17607 f.compute_root();
17608 f.specialize(width > 1).unroll(x, 2);
17609 \endcode
17610 * Assuming for simplicity that width is even, this is equivalent to:
17611 \code
17612 if (width > 1) {
17613 for (int x = 0; x < width/2; x++) {
17614 f[2*x] = 2*x;
17615 f[2*x + 1] = 2*x + 1;
17616 }
17617 } else {
         for (int x = 0; x < width; x++) {
17619 f[x] = x;
17620 }
17621 }
17622 \endcode
17623 * For this case, it may be better to schedule the un-specialized
17624 * case instead:
17625 \code
17626 f(x) = x;
17627 f.compute_root();
17628 f.specialize(width == 1); // Creates a copy of the schedule so far.
17629 f.unroll(x, 2); // Only applies to the unspecialized case.
17630 \endcode
17631 * This is equivalent to:
17632 \code
17633 if (width == 1) {
17634 f[0] = 0;
17635 } else {
17636 for (int x = 0; x < width/2; x++) {
17637 f[2*x] = 2*x;
17638 f[2*x + 1] = 2*x + 1;
17639 }
17640 }
17641 \endcode
17642 * This can be a good way to write a pipeline that splits,
17643 * vectorizes, or tiles, but can still handle small inputs.
17644 *
17645 * If a Func has several specializations, the first matching one
17646 * will be used, so the order in which you define specializations
17647 * is significant. For example:
17648 *
17649 \code
17650 f(x) = x + select(cond1, a, b) - select(cond2, c, d);
17651 f.specialize(cond1);
17652 f.specialize(cond2);
17653 \endcode
17654 * is equivalent to:
17655 \code
17656 if (cond1) {
17657 for (int x = 0; x < width; x++) {
17658 f[x] = x + a - (cond2 ? c : d);
17659 }
17660 } else if (cond2) {
17661 for (int x = 0; x < width; x++) {
17662 f[x] = x + b - c;
17663 }
17664 } else {
17665 for (int x = 0; x < width; x++) {
17666 f[x] = x + b - d;
17667 }
17668 }
17669 \endcode
17670 *
17671 * Specializations may in turn be specialized, which creates a
17672 * nested if statement in the generated code.
17673 *
17674 \code
17675 f(x) = x + select(cond1, a, b) - select(cond2, c, d);
17676 f.specialize(cond1).specialize(cond2);
17677 \endcode
17678 * This is equivalent to:
17679 \code
17680 if (cond1) {
17681 if (cond2) {
17682 for (int x = 0; x < width; x++) {
17683 f[x] = x + a - c;
17684 }
17685 } else {
17686 for (int x = 0; x < width; x++) {
17687 f[x] = x + a - d;
17688 }
17689 }
17690 } else {
17691 for (int x = 0; x < width; x++) {
17692 f[x] = x + b - (cond2 ? c : d);
17693 }
17694 }
17695 \endcode
17696 * To create a 4-way if statement that simplifies away all of the
17697 * ternary operators above, you could say:
17698 \code
17699 f.specialize(cond1).specialize(cond2);
17700 f.specialize(cond2);
17701 \endcode
17702 * or
17703 \code
17704 f.specialize(cond1 && cond2);
17705 f.specialize(cond1);
17706 f.specialize(cond2);
17707 \endcode
17708 *
17709 * Any prior Func which is compute_at some variable of this Func
17710 * gets separately included in all paths of the generated if
     * statement. The Var in the compute_at call must exist in all
17712 * paths, but it may have been generated via a different path of
17713 * splits, fuses, and renames. This can be used somewhat
17714 * creatively. Consider the following code:
17715 \code
17716 g(x, y) = 8*x;
17717 f(x, y) = g(x, y) + 1;
17718 f.compute_root().specialize(cond);
17719 Var g_loop;
17720 f.specialize(cond).rename(y, g_loop);
17721 f.rename(x, g_loop);
17722 g.compute_at(f, g_loop);
17723 \endcode
17724 * When cond is true, this is equivalent to g.compute_at(f,y).
17725 * When it is false, this is equivalent to g.compute_at(f,x).
17726 */
17727 Stage specialize(const Expr &condition);
17728
17729 /** Add a specialization to a Func that always terminates execution
17730 * with a call to halide_error(). By itself, this is of limited use,
17731 * but can be useful to terminate chains of specialize() calls where
17732 * no "default" case is expected (thus avoiding unnecessary code generation).
17733 *
17734 * For instance, say we want to optimize a pipeline to process images
17735 * in planar and interleaved format; we might typically do something like:
17736 \code
17737 ImageParam im(UInt(8), 3);
17738 Func f = do_something_with(im);
17739 f.specialize(im.dim(0).stride() == 1).vectorize(x, 8); // planar
17740 f.specialize(im.dim(2).stride() == 1).reorder(c, x, y).vectorize(c); // interleaved
17741 \endcode
17742 * This code will vectorize along rows for the planar case, and across pixel
17743 * components for the interleaved case... but there is an implicit "else"
17744 * for the unhandled cases, which generates unoptimized code. If we never
     * anticipate passing any other sort of images to this, we could streamline
17746 * our code by adding specialize_fail():
17747 \code
17748 ImageParam im(UInt(8), 3);
17749 Func f = do_something(im);
17750 f.specialize(im.dim(0).stride() == 1).vectorize(x, 8); // planar
17751 f.specialize(im.dim(2).stride() == 1).reorder(c, x, y).vectorize(c); // interleaved
17752 f.specialize_fail("Unhandled image format");
17753 \endcode
     * Conceptually, this produces code like:
17755 \code
17756 if (im.dim(0).stride() == 1) {
17757 do_something_planar();
17758 } else if (im.dim(2).stride() == 1) {
17759 do_something_interleaved();
17760 } else {
17761 halide_error("Unhandled image format");
17762 }
17763 \endcode
17764 *
17765 * Note that calling specialize_fail() terminates the specialization chain
17766 * for a given Func; you cannot create new specializations for the Func
17767 * afterwards (though you can retrieve handles to previous specializations).
17768 */
17769 void specialize_fail(const std::string &message);
17770
17771 /** Tell Halide that the following dimensions correspond to GPU
17772 * thread indices. This is useful if you compute a producer
17773 * function within the block indices of a consumer function, and
17774 * want to control how that function's dimensions map to GPU
17775 * threads. If the selected target is not an appropriate GPU, this
17776 * just marks those dimensions as parallel. */
17777 // @{
17778 Func &gpu_threads(const VarOrRVar &thread_x, DeviceAPI device_api = DeviceAPI::Default_GPU);
17779 Func &gpu_threads(const VarOrRVar &thread_x, const VarOrRVar &thread_y, DeviceAPI device_api = DeviceAPI::Default_GPU);
17780 Func &gpu_threads(const VarOrRVar &thread_x, const VarOrRVar &thread_y, const VarOrRVar &thread_z, DeviceAPI device_api = DeviceAPI::Default_GPU);
17781 // @}
17782
17783 /** The given dimension corresponds to the lanes in a GPU
17784 * warp. GPU warp lanes are distinguished from GPU threads by the
17785 * fact that all warp lanes run together in lockstep, which
17786 * permits lightweight communication of data from one lane to
17787 * another. */
17788 Func &gpu_lanes(const VarOrRVar &thread_x, DeviceAPI device_api = DeviceAPI::Default_GPU);
17789
17790 /** Tell Halide to run this stage using a single gpu thread and
17791 * block. This is not an efficient use of your GPU, but it can be
17792 * useful to avoid copy-back for intermediate update stages that
17793 * touch a very small part of your Func. */
17794 Func &gpu_single_thread(DeviceAPI device_api = DeviceAPI::Default_GPU);
17795
17796 /** Tell Halide that the following dimensions correspond to GPU
17797 * block indices. This is useful for scheduling stages that will
17798 * run serially within each GPU block. If the selected target is
17799 * not ptx, this just marks those dimensions as parallel. */
17800 // @{
17801 Func &gpu_blocks(const VarOrRVar &block_x, DeviceAPI device_api = DeviceAPI::Default_GPU);
17802 Func &gpu_blocks(const VarOrRVar &block_x, const VarOrRVar &block_y, DeviceAPI device_api = DeviceAPI::Default_GPU);
17803 Func &gpu_blocks(const VarOrRVar &block_x, const VarOrRVar &block_y, const VarOrRVar &block_z, DeviceAPI device_api = DeviceAPI::Default_GPU);
17804 // @}
17805
17806 /** Tell Halide that the following dimensions correspond to GPU
17807 * block indices and thread indices. If the selected target is not
17808 * ptx, these just mark the given dimensions as parallel. The
17809 * dimensions are consumed by this call, so do all other
17810 * unrolling, reordering, etc first. */
17811 // @{
17812 Func &gpu(const VarOrRVar &block_x, const VarOrRVar &thread_x, DeviceAPI device_api = DeviceAPI::Default_GPU);
17813 Func &gpu(const VarOrRVar &block_x, const VarOrRVar &block_y,
17814 const VarOrRVar &thread_x, const VarOrRVar &thread_y, DeviceAPI device_api = DeviceAPI::Default_GPU);
17815 Func &gpu(const VarOrRVar &block_x, const VarOrRVar &block_y, const VarOrRVar &block_z,
17816 const VarOrRVar &thread_x, const VarOrRVar &thread_y, const VarOrRVar &thread_z, DeviceAPI device_api = DeviceAPI::Default_GPU);
17817 // @}
17818
17819 /** Short-hand for tiling a domain and mapping the tile indices
17820 * to GPU block indices and the coordinates within each tile to
17821 * GPU thread indices. Consumes the variables given, so do all
17822 * other scheduling first. */
17823 // @{
17824 Func &gpu_tile(const VarOrRVar &x, const VarOrRVar &bx, const VarOrRVar &tx, const Expr &x_size,
17825 TailStrategy tail = TailStrategy::Auto,
17826 DeviceAPI device_api = DeviceAPI::Default_GPU);
17827
17828 Func &gpu_tile(const VarOrRVar &x, const VarOrRVar &tx, const Expr &x_size,
17829 TailStrategy tail = TailStrategy::Auto,
17830 DeviceAPI device_api = DeviceAPI::Default_GPU);
17831 Func &gpu_tile(const VarOrRVar &x, const VarOrRVar &y,
17832 const VarOrRVar &bx, const VarOrRVar &by,
17833 const VarOrRVar &tx, const VarOrRVar &ty,
17834 const Expr &x_size, const Expr &y_size,
17835 TailStrategy tail = TailStrategy::Auto,
17836 DeviceAPI device_api = DeviceAPI::Default_GPU);
17837
17838 Func &gpu_tile(const VarOrRVar &x, const VarOrRVar &y,
17839 const VarOrRVar &tx, const VarOrRVar &ty,
17840 const Expr &x_size, const Expr &y_size,
17841 TailStrategy tail = TailStrategy::Auto,
17842 DeviceAPI device_api = DeviceAPI::Default_GPU);
17843
17844 Func &gpu_tile(const VarOrRVar &x, const VarOrRVar &y, const VarOrRVar &z,
17845 const VarOrRVar &bx, const VarOrRVar &by, const VarOrRVar &bz,
17846 const VarOrRVar &tx, const VarOrRVar &ty, const VarOrRVar &tz,
17847 const Expr &x_size, const Expr &y_size, const Expr &z_size,
17848 TailStrategy tail = TailStrategy::Auto,
17849 DeviceAPI device_api = DeviceAPI::Default_GPU);
17850 Func &gpu_tile(const VarOrRVar &x, const VarOrRVar &y, const VarOrRVar &z,
17851 const VarOrRVar &tx, const VarOrRVar &ty, const VarOrRVar &tz,
17852 const Expr &x_size, const Expr &y_size, const Expr &z_size,
17853 TailStrategy tail = TailStrategy::Auto,
17854 DeviceAPI device_api = DeviceAPI::Default_GPU);
17855 // @}
17856
17857 /** Schedule for execution on Hexagon. When a loop is marked with
17858 * Hexagon, that loop is executed on a Hexagon DSP. */
17859 Func &hexagon(const VarOrRVar &x = Var::outermost());
17860
17861 /** Prefetch data written to or read from a Func or an ImageParam by a
17862 * subsequent loop iteration, at an optionally specified iteration offset.
17863 * 'var' specifies at which loop level the prefetch calls should be inserted.
17864 * The final argument specifies how prefetch of region outside bounds
17865 * should be handled.
17866 *
17867 * For example, consider this pipeline:
17868 \code
17869 Func f, g;
17870 Var x, y;
17871 f(x, y) = x + y;
17872 g(x, y) = 2 * f(x, y);
17873 \endcode
17874 *
17875 * The following schedule:
17876 \code
17877 f.compute_root();
17878 g.prefetch(f, x, 2, PrefetchBoundStrategy::NonFaulting);
17879 \endcode
17880 *
17881 * will inject prefetch call at the innermost loop of 'g' and generate
17882 * the following loop nest:
17883 * for y = ...
17884 * for x = ...
17885 * f(x, y) = x + y
17886 * for y = ..
17887 * for x = ...
17888 * prefetch(&f[x + 2, y], 1, 16);
17889 * g(x, y) = 2 * f(x, y)
17890 */
17891 // @{
17892 Func &prefetch(const Func &f, const VarOrRVar &var, Expr offset = 1,
17893 PrefetchBoundStrategy strategy = PrefetchBoundStrategy::GuardWithIf);
17894 Func &prefetch(const Internal::Parameter &param, const VarOrRVar &var, Expr offset = 1,
17895 PrefetchBoundStrategy strategy = PrefetchBoundStrategy::GuardWithIf);
17896 template<typename T>
17897 Func &prefetch(const T &image, VarOrRVar var, Expr offset = 1,
17898 PrefetchBoundStrategy strategy = PrefetchBoundStrategy::GuardWithIf) {
17899 return prefetch(image.parameter(), var, offset, strategy);
17900 }
17901 // @}
17902
17903 /** Specify how the storage for the function is laid out. These
17904 * calls let you specify the nesting order of the dimensions. For
17905 * example, foo.reorder_storage(y, x) tells Halide to use
17906 * column-major storage for any realizations of foo, without
17907 * changing how you refer to foo in the code. You may want to do
17908 * this if you intend to vectorize across y. When representing
17909 * color images, foo.reorder_storage(c, x, y) specifies packed
17910 * storage (red, green, and blue values adjacent in memory), and
17911 * foo.reorder_storage(x, y, c) specifies planar storage (entire
17912 * red, green, and blue images one after the other in memory).
17913 *
17914 * If you leave out some dimensions, those remain in the same
17915 * positions in the nesting order while the specified variables
17916 * are reordered around them. */
17917 // @{
17918 Func &reorder_storage(const std::vector<Var> &dims);
17919
17920 Func &reorder_storage(const Var &x, const Var &y);
17921 template<typename... Args>
17922 HALIDE_NO_USER_CODE_INLINE typename std::enable_if<Internal::all_are_convertible<Var, Args...>::value, Func &>::type
17923 reorder_storage(const Var &x, const Var &y, Args &&...args) {
17924 std::vector<Var> collected_args{x, y, std::forward<Args>(args)...};
17925 return reorder_storage(collected_args);
17926 }
17927 // @}
17928
17929 /** Pad the storage extent of a particular dimension of
17930 * realizations of this function up to be a multiple of the
17931 * specified alignment. This guarantees that the strides for the
17932 * dimensions stored outside of dim will be multiples of the
17933 * specified alignment, where the strides and alignment are
17934 * measured in numbers of elements.
17935 *
17936 * For example, to guarantee that a function foo(x, y, c)
17937 * representing an image has scanlines starting on offsets
17938 * aligned to multiples of 16, use foo.align_storage(x, 16). */
17939 Func &align_storage(const Var &dim, const Expr &alignment);
17940
17941 /** Store realizations of this function in a circular buffer of a
17942 * given extent. This is more efficient when the extent of the
17943 * circular buffer is a power of 2. If the fold factor is too
17944 * small, or the dimension is not accessed monotonically, the
17945 * pipeline will generate an error at runtime.
17946 *
17947 * The fold_forward option indicates that the new values of the
17948 * producer are accessed by the consumer in a monotonically
17949 * increasing order. Folding storage of producers is also
17950 * supported if the new values are accessed in a monotonically
17951 * decreasing order by setting fold_forward to false.
17952 *
17953 * For example, consider the pipeline:
17954 \code
17955 Func f, g;
17956 Var x, y;
17957 g(x, y) = x*y;
17958 f(x, y) = g(x, y) + g(x, y+1);
17959 \endcode
17960 *
17961 * If we schedule f like so:
17962 *
17963 \code
17964 g.compute_at(f, y).store_root().fold_storage(y, 2);
17965 \endcode
17966 *
17967 * Then g will be computed at each row of f and stored in a buffer
17968 * with an extent in y of 2, alternately storing each computed row
17969 * of g in row y=0 or y=1.
17970 */
17971 Func &fold_storage(const Var &dim, const Expr &extent, bool fold_forward = true);
17972
17973 /** Compute this function as needed for each unique value of the
17974 * given var for the given calling function f.
17975 *
17976 * For example, consider the simple pipeline:
17977 \code
17978 Func f, g;
17979 Var x, y;
17980 g(x, y) = x*y;
17981 f(x, y) = g(x, y) + g(x, y+1) + g(x+1, y) + g(x+1, y+1);
17982 \endcode
17983 *
17984 * If we schedule f like so:
17985 *
17986 \code
17987 g.compute_at(f, x);
17988 \endcode
17989 *
17990 * Then the C code equivalent to this pipeline will look like this
17991 *
17992 \code
17993
17994 int f[height][width];
17995 for (int y = 0; y < height; y++) {
17996 for (int x = 0; x < width; x++) {
17997 int g[2][2];
17998 g[0][0] = x*y;
17999 g[0][1] = (x+1)*y;
18000 g[1][0] = x*(y+1);
18001 g[1][1] = (x+1)*(y+1);
18002 f[y][x] = g[0][0] + g[1][0] + g[0][1] + g[1][1];
18003 }
18004 }
18005
18006 \endcode
18007 *
18008 * The allocation and computation of g is within f's loop over x,
18009 * and enough of g is computed to satisfy all that f will need for
18010 * that iteration. This has excellent locality - values of g are
18011 * used as soon as they are computed, but it does redundant
18012 * work. Each value of g ends up getting computed four times. If
18013 * we instead schedule f like so:
18014 *
18015 \code
18016 g.compute_at(f, y);
18017 \endcode
18018 *
18019 * The equivalent C code is:
18020 *
18021 \code
18022 int f[height][width];
18023 for (int y = 0; y < height; y++) {
18024 int g[2][width+1];
18025 for (int x = 0; x < width; x++) {
18026 g[0][x] = x*y;
18027 g[1][x] = x*(y+1);
18028 }
18029 for (int x = 0; x < width; x++) {
18030 f[y][x] = g[0][x] + g[1][x] + g[0][x+1] + g[1][x+1];
18031 }
18032 }
18033 \endcode
18034 *
18035 * The allocation and computation of g is within f's loop over y,
18036 * and enough of g is computed to satisfy all that f will need for
18037 * that iteration. This does less redundant work (each point in g
18038 * ends up being evaluated twice), but the locality is not quite
18039 * as good, and we have to allocate more temporary memory to store
18040 * g.
18041 */
18042 Func &compute_at(const Func &f, const Var &var);
18043
18044 /** Schedule a function to be computed within the iteration over
18045 * some dimension of an update domain. Produces equivalent code
18046 * to the version of compute_at that takes a Var. */
18047 Func &compute_at(const Func &f, const RVar &var);
18048
18049 /** Schedule a function to be computed within the iteration over
18050 * a given LoopLevel. */
18051 Func &compute_at(LoopLevel loop_level);
18052
18053 /** Schedule the iteration over the initial definition of this function
18054 * to be fused with another stage 's' from outermost loop to a
18055 * given LoopLevel. */
18056 // @{
18057 Func &compute_with(const Stage &s, const VarOrRVar &var, const std::vector<std::pair<VarOrRVar, LoopAlignStrategy>> &align);
18058 Func &compute_with(const Stage &s, const VarOrRVar &var, LoopAlignStrategy align = LoopAlignStrategy::Auto);
18059 Func &compute_with(LoopLevel loop_level, const std::vector<std::pair<VarOrRVar, LoopAlignStrategy>> &align);
18060 Func &compute_with(LoopLevel loop_level, LoopAlignStrategy align = LoopAlignStrategy::Auto);
18061
18062 /** Compute all of this function once ahead of time. Reusing
18063 * the example in \ref Func::compute_at :
18064 *
18065 \code
18066 Func f, g;
18067 Var x, y;
18068 g(x, y) = x*y;
18069 f(x, y) = g(x, y) + g(x, y+1) + g(x+1, y) + g(x+1, y+1);
18070
18071 g.compute_root();
18072 \endcode
18073 *
18074 * is equivalent to
18075 *
18076 \code
18077 int f[height][width];
18078 int g[height+1][width+1];
18079 for (int y = 0; y < height+1; y++) {
18080 for (int x = 0; x < width+1; x++) {
18081 g[y][x] = x*y;
18082 }
18083 }
18084 for (int y = 0; y < height; y++) {
18085 for (int x = 0; x < width; x++) {
18086 f[y][x] = g[y][x] + g[y+1][x] + g[y][x+1] + g[y+1][x+1];
18087 }
18088 }
18089 \endcode
18090 *
18091 * g is computed once ahead of time, and enough is computed to
18092 * satisfy all uses of it. This does no redundant work (each point
18093 * in g is evaluated once), but has poor locality (values of g are
18094 * probably not still in cache when they are used by f), and
18095 * allocates lots of temporary memory to store g.
18096 */
18097 Func &compute_root();
18098
18099 /** Use the halide_memoization_cache_... interface to store a
18100 * computed version of this function across invocations of the
18101 * Func.
18102 *
18103 * If an eviction_key is provided, it must be constructed with
18104 * Expr of integer or handle type. The key Expr will be promoted
18105 * to a uint64_t and can be used with halide_memoization_cache_evict
18106 * to remove memoized entries using this eviction key from the
18107 * cache. Memoized computations that do not provide an eviction
18108 * key will never be evicted by this mechanism.
18109 */
18110 Func &memoize(const EvictionKey &eviction_key = EvictionKey());
18111
18112 /** Produce this Func asynchronously in a separate
18113 * thread. Consumers will be run by the task system when the
18114 * production is complete. If this Func's store level is different
18115 * to its compute level, consumers will be run concurrently,
18116 * blocking as necessary to prevent reading ahead of what the
18117 * producer has computed. If storage is folded, then the producer
18118 * will additionally not be permitted to run too far ahead of the
18119 * consumer, to avoid clobbering data that has not yet been
18120 * used.
18121 *
18122 * Take special care when combining this with custom thread pool
18123 * implementations, as avoiding deadlock with producer-consumer
18124 * parallelism requires a much more sophisticated parallel runtime
18125 * than with data parallelism alone. It is strongly recommended
18126 * you just use Halide's default thread pool, which guarantees no
18127 * deadlock and a bound on the number of threads launched.
18128 */
18129 Func &async();
18130
18131 /** Allocate storage for this function within f's loop over
18132 * var. Scheduling storage is optional, and can be used to
18133 * separate the loop level at which storage occurs from the loop
18134 * level at which computation occurs to trade off between locality
18135 * and redundant work. This can open the door for two types of
18136 * optimization.
18137 *
18138 * Consider again the pipeline from \ref Func::compute_at :
18139 \code
18140 Func f, g;
18141 Var x, y;
18142 g(x, y) = x*y;
18143 f(x, y) = g(x, y) + g(x+1, y) + g(x, y+1) + g(x+1, y+1);
18144 \endcode
18145 *
18146 * If we schedule it like so:
18147 *
18148 \code
18149 g.compute_at(f, x).store_at(f, y);
18150 \endcode
18151 *
18152 * Then the computation of g takes place within the loop over x,
18153 * but the storage takes place within the loop over y:
18154 *
18155 \code
18156 int f[height][width];
18157 for (int y = 0; y < height; y++) {
18158 int g[2][width+1];
18159 for (int x = 0; x < width; x++) {
18160 g[0][x] = x*y;
18161 g[0][x+1] = (x+1)*y;
18162 g[1][x] = x*(y+1);
18163 g[1][x+1] = (x+1)*(y+1);
18164 f[y][x] = g[0][x] + g[1][x] + g[0][x+1] + g[1][x+1];
18165 }
18166 }
18167 \endcode
18168 *
18169 * Provided the for loop over x is serial, halide then
18170 * automatically performs the following sliding window
18171 * optimization:
18172 *
18173 \code
18174 int f[height][width];
18175 for (int y = 0; y < height; y++) {
18176 int g[2][width+1];
18177 for (int x = 0; x < width; x++) {
18178 if (x == 0) {
18179 g[0][x] = x*y;
18180 g[1][x] = x*(y+1);
18181 }
18182 g[0][x+1] = (x+1)*y;
18183 g[1][x+1] = (x+1)*(y+1);
18184 f[y][x] = g[0][x] + g[1][x] + g[0][x+1] + g[1][x+1];
18185 }
18186 }
18187 \endcode
18188 *
18189 * Two of the assignments to g only need to be done when x is
18190 * zero. The rest of the time, those sites have already been
18191 * filled in by a previous iteration. This version has the
18192 * locality of compute_at(f, x), but allocates more memory and
18193 * does much less redundant work.
18194 *
18195 * Halide then further optimizes this pipeline like so:
18196 *
18197 \code
18198 int f[height][width];
18199 for (int y = 0; y < height; y++) {
18200 int g[2][2];
18201 for (int x = 0; x < width; x++) {
18202 if (x == 0) {
18203 g[0][0] = x*y;
18204 g[1][0] = x*(y+1);
18205 }
18206 g[0][(x+1)%2] = (x+1)*y;
18207 g[1][(x+1)%2] = (x+1)*(y+1);
18208 f[y][x] = g[0][x%2] + g[1][x%2] + g[0][(x+1)%2] + g[1][(x+1)%2];
18209 }
18210 }
18211 \endcode
18212 *
18213 * Halide has detected that it's possible to use a circular buffer
18214 * to represent g, and has reduced all accesses to g modulo 2 in
18215 * the x dimension. This optimization only triggers if the for
18216 * loop over x is serial, and if halide can statically determine
18217 * some power of two large enough to cover the range needed. For
18218 * powers of two, the modulo operator compiles to more efficient
18219 * bit-masking. This optimization reduces memory usage, and also
18220 * improves locality by reusing recently-accessed memory instead
18221 * of pulling new memory into cache.
18222 *
18223 */
18224 Func &store_at(const Func &f, const Var &var);
18225
18226 /** Equivalent to the version of store_at that takes a Var, but
18227 * schedules storage within the loop over a dimension of a
18228 * reduction domain */
18229 Func &store_at(const Func &f, const RVar &var);
18230
18231 /** Equivalent to the version of store_at that takes a Var, but
18232 * schedules storage at a given LoopLevel. */
18233 Func &store_at(LoopLevel loop_level);
18234
18235 /** Equivalent to \ref Func::store_at, but schedules storage
18236 * outside the outermost loop. */
18237 Func &store_root();
18238
18239 /** Aggressively inline all uses of this function. This is the
18240 * default schedule, so you're unlikely to need to call this. For
18241 * a Func with an update definition, that means it gets computed
18242 * as close to the innermost loop as possible.
18243 *
18244 * Consider once more the pipeline from \ref Func::compute_at :
18245 *
18246 \code
18247 Func f, g;
18248 Var x, y;
18249 g(x, y) = x*y;
18250 f(x, y) = g(x, y) + g(x+1, y) + g(x, y+1) + g(x+1, y+1);
18251 \endcode
18252 *
18253 * Leaving g as inline, this compiles to code equivalent to the following C:
18254 *
18255 \code
18256 int f[height][width];
18257 for (int y = 0; y < height; y++) {
18258 for (int x = 0; x < width; x++) {
18259 f[y][x] = x*y + x*(y+1) + (x+1)*y + (x+1)*(y+1);
18260 }
18261 }
18262 \endcode
18263 */
18264 Func &compute_inline();
18265
18266 /** Get a handle on an update step for the purposes of scheduling
18267 * it. */
18268 Stage update(int idx = 0);
18269
18270 /** Set the type of memory this Func should be stored in. Controls
18271 * whether allocations go on the stack or the heap on the CPU, and
18272 * in global vs shared vs local on the GPU. See the documentation
18273 * on MemoryType for more detail. */
18274 Func &store_in(MemoryType memory_type);
18275
18276 /** Trace all loads from this Func by emitting calls to
18277 * halide_trace. If the Func is inlined, this has no
18278 * effect. */
18279 Func &trace_loads();
18280
18281 /** Trace all stores to the buffer backing this Func by emitting
18282 * calls to halide_trace. If the Func is inlined, this call
18283 * has no effect. */
18284 Func &trace_stores();
18285
18286 /** Trace all realizations of this Func by emitting calls to
18287 * halide_trace. */
18288 Func &trace_realizations();
18289
    /** Add a string of arbitrary text that will be passed through to trace
18291 * inspection code if the Func is realized in trace mode. (Funcs that are
18292 * inlined won't have their tags emitted.) Ignored entirely if
18293 * tracing is not enabled for the Func (or globally).
18294 */
18295 Func &add_trace_tag(const std::string &trace_tag);
18296
18297 /** Get a handle on the internal halide function that this Func
18298 * represents. Useful if you want to do introspection on Halide
18299 * functions */
    Internal::Function function() const {
        // Expose the wrapped Internal::Function for introspection.
        return func;
    }
18303
18304 /** You can cast a Func to its pure stage for the purposes of
18305 * scheduling it. */
18306 operator Stage() const;
18307
18308 /** Get a handle on the output buffer for this Func. Only relevant
18309 * if this is the output Func in a pipeline. Useful for making
18310 * static promises about strides, mins, and extents. */
18311 // @{
18312 OutputImageParam output_buffer() const;
18313 std::vector<OutputImageParam> output_buffers() const;
18314 // @}
18315
18316 /** Use a Func as an argument to an external stage. */
18317 operator ExternFuncArgument() const;
18318
18319 /** Infer the arguments to the Func, sorted into a canonical order:
18320 * all buffers (sorted alphabetically by name), followed by all non-buffers
18321 * (sorted alphabetically by name).
18322 This lets you write things like:
18323 \code
18324 func.compile_to_assembly("/dev/stdout", func.infer_arguments());
18325 \endcode
18326 */
18327 std::vector<Argument> infer_arguments() const;
18328
18329 /** Get the source location of the pure definition of this
18330 * Func. See Stage::source_location() */
18331 std::string source_location() const;
18332
18333 /** Return the current StageSchedule associated with this initial
18334 * Stage of this Func. For introspection only: to modify schedule,
18335 * use the Func interface. */
    const Internal::StageSchedule &get_schedule() const {
        // Builds a Stage view of the pure definition and forwards to its
        // get_schedule(). NOTE(review): the reference is obtained through a
        // temporary Stage; presumably it refers to schedule state owned by
        // the underlying Function rather than the temporary — verify.
        return Stage(*this).get_schedule();
    }
18339};
18340
18341namespace Internal {
18342
18343template<typename Last>
18344inline void check_types(const Tuple &t, int idx) {
18345 using T = typename std::remove_pointer<typename std::remove_reference<Last>::type>::type;
18346 user_assert(t[idx].type() == type_of<T>())
18347 << "Can't evaluate expression "
18348 << t[idx] << " of type " << t[idx].type()
18349 << " as a scalar of type " << type_of<T>() << "\n";
18350}
18351
template<typename First, typename Second, typename... Rest>
inline void check_types(const Tuple &t, int idx) {
    // Check the element at 'idx' against First, then recurse on the rest of
    // the pack. Requiring two leading type parameters keeps this overload
    // distinct from the single-type base case above.
    check_types<First>(t, idx);
    check_types<Second, Rest...>(t, idx + 1);
}
18357
18358template<typename Last>
18359inline void assign_results(Realization &r, int idx, Last last) {
18360 using T = typename std::remove_pointer<typename std::remove_reference<Last>::type>::type;
18361 *last = Buffer<T>(r[idx])();
18362}
18363
template<typename First, typename Second, typename... Rest>
inline void assign_results(Realization &r, int idx, First first, Second second, Rest &&...rest) {
    // Assign component 'idx' through 'first', then recurse for the remaining
    // output pointers. Two leading parameters disambiguate this overload
    // from the single-argument base case above.
    assign_results<First>(r, idx, first);
    assign_results<Second, Rest...>(r, idx + 1, second, rest...);
}
18369
18370} // namespace Internal
18371
18372/** JIT-Compile and run enough code to evaluate a Halide
18373 * expression. This can be thought of as a scalar version of
18374 * \ref Func::realize */
18375template<typename T>
18376HALIDE_NO_USER_CODE_INLINE T evaluate(const Expr &e) {
18377 user_assert(e.type() == type_of<T>())
18378 << "Can't evaluate expression "
18379 << e << " of type " << e.type()
18380 << " as a scalar of type " << type_of<T>() << "\n";
18381 Func f;
18382 f() = e;
18383 Buffer<T> im = f.realize();
18384 return im();
18385}
18386
18387/** JIT-compile and run enough code to evaluate a Halide Tuple. */
18388template<typename First, typename... Rest>
18389HALIDE_NO_USER_CODE_INLINE void evaluate(Tuple t, First first, Rest &&...rest) {
18390 Internal::check_types<First, Rest...>(t, 0);
18391
18392 Func f;
18393 f() = t;
18394 Realization r = f.realize();
18395 Internal::assign_results(r, 0, first, rest...);
18396}
18397
18398namespace Internal {
18399
18400inline void schedule_scalar(Func f) {
18401 Target t = get_jit_target_from_environment();
18402 if (t.has_gpu_feature()) {
18403 f.gpu_single_thread();
18404 }
18405 if (t.has_feature(Target::HVX)) {
18406 f.hexagon();
18407 }
18408}
18409
18410} // namespace Internal
18411
18412/** JIT-Compile and run enough code to evaluate a Halide
18413 * expression. This can be thought of as a scalar version of
18414 * \ref Func::realize. Can use GPU if jit target from environment
18415 * specifies one.
18416 */
18417template<typename T>
18418HALIDE_NO_USER_CODE_INLINE T evaluate_may_gpu(const Expr &e) {
18419 user_assert(e.type() == type_of<T>())
18420 << "Can't evaluate expression "
18421 << e << " of type " << e.type()
18422 << " as a scalar of type " << type_of<T>() << "\n";
18423 Func f;
18424 f() = e;
18425 Internal::schedule_scalar(f);
18426 Buffer<T> im = f.realize();
18427 return im();
18428}
18429
18430/** JIT-compile and run enough code to evaluate a Halide Tuple. Can
18431 * use GPU if jit target from environment specifies one. */
18432// @{
18433template<typename First, typename... Rest>
18434HALIDE_NO_USER_CODE_INLINE void evaluate_may_gpu(Tuple t, First first, Rest &&...rest) {
18435 Internal::check_types<First, Rest...>(t, 0);
18436
18437 Func f;
18438 f() = t;
18439 Internal::schedule_scalar(f);
18440 Realization r = f.realize();
18441 Internal::assign_results(r, 0, first, rest...);
18442}
18443// @}
18444
18445} // namespace Halide
18446
18447#endif
18448#ifndef HALIDE_LAMBDA_H
18449#define HALIDE_LAMBDA_H
18450
18451
18452/** \file
18453 * Convenience functions for creating small anonymous Halide
18454 * functions. See test/lambda.cpp for example usage. */
18455
18456namespace Halide {
18457
18458/** Create a zero-dimensional halide function that returns the given
18459 * expression. The function may have more dimensions if the expression
18460 * contains implicit arguments. */
18461Func lambda(const Expr &e);
18462
18463/** Create a 1-D halide function in the first argument that returns
18464 * the second argument. The function may have more dimensions if the
18465 * expression contains implicit arguments and the list of Var
18466 * arguments contains a placeholder ("_"). */
18467Func lambda(const Var &x, const Expr &e);
18468
18469/** Create a 2-D halide function in the first two arguments that
18470 * returns the last argument. The function may have more dimensions if
18471 * the expression contains implicit arguments and the list of Var
18472 * arguments contains a placeholder ("_"). */
18473Func lambda(const Var &x, const Var &y, const Expr &e);
18474
18475/** Create a 3-D halide function in the first three arguments that
18476 * returns the last argument. The function may have more dimensions
18477 * if the expression contains implicit arguments and the list of Var
18478 * arguments contains a placeholder ("_"). */
18479Func lambda(const Var &x, const Var &y, const Var &z, const Expr &e);
18480
18481/** Create a 4-D halide function in the first four arguments that
18482 * returns the last argument. The function may have more dimensions if
18483 * the expression contains implicit arguments and the list of Var
18484 * arguments contains a placeholder ("_"). */
18485Func lambda(const Var &x, const Var &y, const Var &z, const Var &w, const Expr &e);
18486
18487/** Create a 5-D halide function in the first five arguments that
18488 * returns the last argument. The function may have more dimensions if
18489 * the expression contains implicit arguments and the list of Var
18490 * arguments contains a placeholder ("_"). */
18491Func lambda(const Var &x, const Var &y, const Var &z, const Var &w, const Var &v, const Expr &e);
18492
18493} // namespace Halide
18494
18495#endif //HALIDE_LAMBDA_H
18496
18497namespace Halide {
18498
18499/** namespace to hold functions for imposing boundary conditions on
18500 * Halide Funcs.
18501 *
18502 * All functions in this namespace transform a source Func to a
18503 * result Func where the result produces the values of the source
18504 * within a given region and a different set of values outside the
18505 * given region. A region is an N dimensional box specified by
18506 * mins and extents.
18507 *
18508 * Three areas are defined:
18509 * The image is the entire set of values in the region.
 * The edge is the set of pixels in the image that are adjacent
 * to coordinates that are not in the image.
18512 * The interior is the image minus the edge (and is undefined
18513 * if the extent of any region is 1 or less).
18514 *
18515 * If the source Func has more dimensions than are specified, the extra ones
18516 * are unmodified. Additionally, passing an undefined (default constructed)
18517 * 'Expr' for the min and extent of a dimension will keep that dimension
18518 * unmodified.
18519 *
 * Numerous options for specifying the outside area are provided,
18521 * including replacement with an expression, repeating the edge
18522 * samples, mirroring over the edge, and repeating or mirroring the
18523 * entire image.
18524 *
18525 * Using these functions to express your boundary conditions is highly
18526 * recommended for correctness and performance. Some of these are hard
18527 * to get right. The versions here are both understood by bounds
18528 * inference, and also judiciously use the 'likely' intrinsic to minimize
18529 * runtime overhead.
18530 *
18531 */
18532namespace BoundaryConditions {
18533
18534namespace Internal {
18535
18536inline HALIDE_NO_USER_CODE_INLINE void collect_region(Region &collected_args,
18537 const Expr &a1, const Expr &a2) {
18538 collected_args.push_back(Range(a1, a2));
18539}
18540
18541template<typename... Args>
18542inline HALIDE_NO_USER_CODE_INLINE void collect_region(Region &collected_args,
18543 const Expr &a1, const Expr &a2, Args &&...args) {
18544 collected_args.push_back(Range(a1, a2));
18545 collect_region(collected_args, std::forward<Args>(args)...);
18546}
18547
inline const Func &func_like_to_func(const Func &func) {
    // Identity overload: an actual Func needs no wrapping.
    return func;
}
18551
template<typename T>
inline HALIDE_NO_USER_CODE_INLINE Func func_like_to_func(const T &func_like) {
    // Wrap any callable image-like object (ImageParam, Buffer<T>, etc.) in
    // a Func via lambda, using the placeholder to forward all dimensions.
    return lambda(_, func_like(_));
}
18556
18557} // namespace Internal
18558
18559/** Impose a boundary condition such that a given expression is returned
18560 * everywhere outside the boundary. Generally the expression will be a
18561 * constant, though the code currently allows accessing the arguments
18562 * of source.
18563 *
18564 * An ImageParam, Buffer<T>, or similar can be passed instead of a
18565 * Func. If this is done and no bounds are given, the boundaries will
18566 * be taken from the min and extent methods of the passed
18567 * object. Note that objects are taken by mutable ref. Pipelines
18568 * capture Buffers via mutable refs, because running a pipeline might
18569 * alter the Buffer metadata (e.g. device allocation state).
18570 *
18571 * (This is similar to setting GL_TEXTURE_WRAP_* to GL_CLAMP_TO_BORDER
18572 * and putting value in the border of the texture.)
18573 *
18574 * You may pass undefined Exprs for dimensions that you do not wish
18575 * to bound.
18576 */
18577// @{
18578Func constant_exterior(const Func &source, const Tuple &value,
18579 const Region &bounds);
18580Func constant_exterior(const Func &source, const Expr &value,
18581 const Region &bounds);
18582
template<typename T>
HALIDE_NO_USER_CODE_INLINE Func constant_exterior(const T &func_like, const Tuple &value, const Region &bounds) {
    // Convert the func-like object to a Func and defer to the non-template overload.
    return constant_exterior(Internal::func_like_to_func(func_like), value, bounds);
}
18587
template<typename T>
HALIDE_NO_USER_CODE_INLINE Func constant_exterior(const T &func_like, const Expr &value, const Region &bounds) {
    // Convert the func-like object to a Func and defer to the non-template overload.
    return constant_exterior(Internal::func_like_to_func(func_like), value, bounds);
}
18592
18593template<typename T>
18594HALIDE_NO_USER_CODE_INLINE Func constant_exterior(const T &func_like, const Tuple &value) {
18595 Region object_bounds;
18596 for (int i = 0; i < func_like.dimensions(); i++) {
18597 object_bounds.push_back({Expr(func_like.dim(i).min()), Expr(func_like.dim(i).extent())});
18598 }
18599
18600 return constant_exterior(Internal::func_like_to_func(func_like), value, object_bounds);
18601}
template<typename T>
HALIDE_NO_USER_CODE_INLINE Func constant_exterior(const T &func_like, const Expr &value) {
    // Promote the scalar value to a one-element Tuple and reuse that overload.
    return constant_exterior(func_like, Tuple(value));
}
18606
18607template<typename T, typename... Bounds,
18608 typename std::enable_if<Halide::Internal::all_are_convertible<Expr, Bounds...>::value>::type * = nullptr>
18609HALIDE_NO_USER_CODE_INLINE Func constant_exterior(const T &func_like, const Tuple &value,
18610 Bounds &&...bounds) {
18611 Region collected_bounds;
18612 Internal::collect_region(collected_bounds, std::forward<Bounds>(bounds)...);
18613 return constant_exterior(Internal::func_like_to_func(func_like), value, collected_bounds);
18614}
template<typename T, typename... Bounds,
         typename std::enable_if<Halide::Internal::all_are_convertible<Expr, Bounds...>::value>::type * = nullptr>
HALIDE_NO_USER_CODE_INLINE Func constant_exterior(const T &func_like, const Expr &value,
                                                  Bounds &&...bounds) {
    // Promote the scalar value to a one-element Tuple and reuse that overload.
    return constant_exterior(func_like, Tuple(value), std::forward<Bounds>(bounds)...);
}
18621// @}
18622
18623/** Impose a boundary condition such that the nearest edge sample is returned
18624 * everywhere outside the given region.
18625 *
18626 * An ImageParam, Buffer<T>, or similar can be passed instead of a Func. If this
18627 * is done and no bounds are given, the boundaries will be taken from the
18628 * min and extent methods of the passed object.
18629 *
18630 * (This is similar to setting GL_TEXTURE_WRAP_* to GL_CLAMP_TO_EDGE.)
18631 *
18632 * You may pass undefined Exprs for dimensions that you do not wish
18633 * to bound.
18634 */
18635// @{
18636Func repeat_edge(const Func &source, const Region &bounds);
18637
template<typename T>
HALIDE_NO_USER_CODE_INLINE Func repeat_edge(const T &func_like, const Region &bounds) {
    // Convert the func-like object to a Func and defer to the non-template overload.
    return repeat_edge(Internal::func_like_to_func(func_like), bounds);
}
18642
18643template<typename T>
18644HALIDE_NO_USER_CODE_INLINE Func repeat_edge(const T &func_like) {
18645 Region object_bounds;
18646 for (int i = 0; i < func_like.dimensions(); i++) {
18647 object_bounds.push_back({Expr(func_like.dim(i).min()), Expr(func_like.dim(i).extent())});
18648 }
18649
18650 return repeat_edge(Internal::func_like_to_func(func_like), object_bounds);
18651}
18652// @}
18653
18654/** Impose a boundary condition such that the entire coordinate space is
18655 * tiled with copies of the image abutted against each other.
18656 *
18657 * An ImageParam, Buffer<T>, or similar can be passed instead of a Func. If this
18658 * is done and no bounds are given, the boundaries will be taken from the
18659 * min and extent methods of the passed object.
18660 *
18661 * (This is similar to setting GL_TEXTURE_WRAP_* to GL_REPEAT.)
18662 *
18663 * You may pass undefined Exprs for dimensions that you do not wish
18664 * to bound.
18665 */
18666// @{
18667Func repeat_image(const Func &source, const Region &bounds);
18668
template<typename T>
HALIDE_NO_USER_CODE_INLINE Func repeat_image(const T &func_like, const Region &bounds) {
    // Convert the func-like object to a Func and defer to the non-template overload.
    return repeat_image(Internal::func_like_to_func(func_like), bounds);
}
18673
18674template<typename T>
18675HALIDE_NO_USER_CODE_INLINE Func repeat_image(const T &func_like) {
18676 Region object_bounds;
18677 for (int i = 0; i < func_like.dimensions(); i++) {
18678 object_bounds.push_back({Expr(func_like.dim(i).min()), Expr(func_like.dim(i).extent())});
18679 }
18680
18681 return repeat_image(Internal::func_like_to_func(func_like), object_bounds);
18682}
18683
18684/** Impose a boundary condition such that the entire coordinate space is
18685 * tiled with copies of the image abutted against each other, but mirror
18686 * them such that adjacent edges are the same.
18687 *
18688 * An ImageParam, Buffer<T>, or similar can be passed instead of a Func. If this
18689 * is done and no bounds are given, the boundaries will be taken from the
18690 * min and extent methods of the passed object.
18691 *
18692 * (This is similar to setting GL_TEXTURE_WRAP_* to GL_MIRRORED_REPEAT.)
18693 *
18694 * You may pass undefined Exprs for dimensions that you do not wish
18695 * to bound.
18696 */
18697// @{
18698Func mirror_image(const Func &source, const Region &bounds);
18699
template<typename T>
HALIDE_NO_USER_CODE_INLINE Func mirror_image(const T &func_like, const Region &bounds) {
    // Convert the func-like object to a Func and defer to the non-template overload.
    return mirror_image(Internal::func_like_to_func(func_like), bounds);
}
18704
18705template<typename T>
18706HALIDE_NO_USER_CODE_INLINE Func mirror_image(const T &func_like) {
18707 Region object_bounds;
18708 for (int i = 0; i < func_like.dimensions(); i++) {
18709 object_bounds.push_back({Expr(func_like.dim(i).min()), Expr(func_like.dim(i).extent())});
18710 }
18711
18712 return mirror_image(Internal::func_like_to_func(func_like), object_bounds);
18713}
18714
18715// @}
18716
18717/** Impose a boundary condition such that the entire coordinate space is
18718 * tiled with copies of the image abutted against each other, but mirror
18719 * them such that adjacent edges are the same and then overlap the edges.
18720 *
18721 * This produces an error if any extent is 1 or less. (TODO: check this.)
18722 *
18723 * An ImageParam, Buffer<T>, or similar can be passed instead of a Func. If this
18724 * is done and no bounds are given, the boundaries will be taken from the
18725 * min and extent methods of the passed object.
18726 *
18727 * (I do not believe there is a direct GL_TEXTURE_WRAP_* equivalent for this.)
18728 *
18729 * You may pass undefined Exprs for dimensions that you do not wish
18730 * to bound.
18731 */
18732// @{
18733Func mirror_interior(const Func &source, const Region &bounds);
18734
template<typename T>
HALIDE_NO_USER_CODE_INLINE Func mirror_interior(const T &func_like, const Region &bounds) {
    // Convert the func-like object to a Func and defer to the non-template overload.
    return mirror_interior(Internal::func_like_to_func(func_like), bounds);
}
18739
18740template<typename T>
18741HALIDE_NO_USER_CODE_INLINE Func mirror_interior(const T &func_like) {
18742 Region object_bounds;
18743 for (int i = 0; i < func_like.dimensions(); i++) {
18744 object_bounds.push_back({Expr(func_like.dim(i).min()), Expr(func_like.dim(i).extent())});
18745 }
18746
18747 return mirror_interior(Internal::func_like_to_func(func_like), object_bounds);
18748}
18749
18750// @}
18751
18752} // namespace BoundaryConditions
18753
18754} // namespace Halide
18755
18756#endif
18757#ifndef HALIDE_BOUNDS_INFERENCE_H
18758#define HALIDE_BOUNDS_INFERENCE_H
18759
18760/** \file
18761 * Defines the bounds_inference lowering pass.
18762 */
18763
18764#include <map>
18765#include <string>
18766#include <vector>
18767
18768
18769namespace Halide {
18770
18771struct Target;
18772
18773namespace Internal {
18774
18775class Function;
18776
18777/** Take a partially lowered statement that includes symbolic
18778 * representations of the bounds over which things should be realized,
18779 * and inject expressions defining those bounds.
18780 */
18781Stmt bounds_inference(Stmt,
18782 const std::vector<Function> &outputs,
18783 const std::vector<std::string> &realization_order,
18784 const std::vector<std::vector<std::string>> &fused_groups,
18785 const std::map<std::string, Function> &environment,
18786 const std::map<std::pair<std::string, int>, Interval> &func_bounds,
18787 const Target &target);
18788
18789} // namespace Internal
18790} // namespace Halide
18791
18792#endif
18793#ifndef HALIDE_BOUND_SMALL_ALLOCATIONS
18794#define HALIDE_BOUND_SMALL_ALLOCATIONS
18795
18796
18797/** \file
18798 * Defines the lowering pass that attempts to rewrite small
18799 * allocations to have constant size.
18800 */
18801
18802namespace Halide {
18803namespace Internal {
18804
18805/** \file
18806 *
18807 * Use bounds analysis to attempt to bound the sizes of small
18808 * allocations. Inside GPU kernels this is necessary in order to
18809 * compile. On the CPU this is also useful, because it prevents malloc
18810 * calls for (provably) tiny allocations. */
18811Stmt bound_small_allocations(const Stmt &s);
18812
18813} // namespace Internal
18814} // namespace Halide
18815
18816#endif
18817#ifndef HALIDE_CANONICALIZE_GPU_VARS_H
18818#define HALIDE_CANONICALIZE_GPU_VARS_H
18819
18820/** \file
18821 * Defines the lowering pass that canonicalize the GPU var names over.
18822 */
18823
18824
18825namespace Halide {
18826namespace Internal {
18827
18828/** Canonicalize GPU var names into some pre-determined block/thread names
18829 * (i.e. __block_id_x, __thread_id_x, etc.). The x/y/z/w order is determined
18830 * by the nesting order: innermost is assigned to x and so on. */
18831Stmt canonicalize_gpu_vars(Stmt s);
18832
18833} // namespace Internal
18834} // namespace Halide
18835
18836#endif
18837#ifndef HALIDE_CLOSURE_H
18838#define HALIDE_CLOSURE_H
18839
18840/** \file
18841 *
18842 * Provides Closure class.
18843 */
18844#include <map>
18845#include <string>
18846
18847
18848namespace Halide {
18849
18850template<typename T>
18851class Buffer;
18852
18853namespace Internal {
18854
18855/** A helper class to manage closures. Walks over a statement and
18856 * retrieves all the references within it to external symbols
18857 * (variables and allocations). It then helps you build a struct
18858 * containing the current values of these symbols that you can use as
18859 * a closure if you want to migrate the body of the statement to its
18860 * own function (e.g. because it's the body of a parallel for loop. */
class Closure : public IRVisitor {
protected:
    // Names bound within the statement being walked (by lets, loop
    // variables, allocations). NOTE(review): the push/pop discipline lives
    // in the visit() definitions, which are not in this header — verify
    // there that these are excluded from 'vars'/'buffers'.
    Scope<> ignore;

    using IRVisitor::visit;

    // IR nodes that either bind names or can reference external symbols.
    void visit(const Let *op) override;
    void visit(const LetStmt *op) override;
    void visit(const For *op) override;
    void visit(const Load *op) override;
    void visit(const Store *op) override;
    void visit(const Allocate *op) override;
    void visit(const Variable *op) override;
    void visit(const Atomic *op) override;

public:
    /** Information about a buffer reference from a closure. */
    struct Buffer {
        /** The type of the buffer referenced. */
        Type type;

        /** The dimensionality of the buffer. */
        uint8_t dimensions = 0;

        /** The buffer is read from. */
        bool read = false;

        /** The buffer is written to. */
        bool write = false;

        /** The buffer is a texture */
        MemoryType memory_type = MemoryType::Auto;

        /** The size of the buffer if known, otherwise zero. */
        size_t size = 0;

        Buffer() = default;
    };

protected:
    // Helper used during traversal to record a reference to a buffer
    // (definition is not in this header).
    void found_buffer_ref(const std::string &name, Type type,
                          bool read, bool written, const Halide::Buffer<void> &image);

public:
    Closure() = default;

    /** Traverse a statement and find all references to external
     * symbols.
     *
     * When the closure encounters a read or write to 'foo', it
     * assumes that the host pointer is found in the symbol table as
     * 'foo.host', and any halide_buffer_t pointer is found under
     * 'foo.buffer'. */
    Closure(const Stmt &s, const std::string &loop_variable = "");

    /** External variables referenced. */
    std::map<std::string, Type> vars;

    /** External allocations referenced. */
    std::map<std::string, Buffer> buffers;
};
18922
18923} // namespace Internal
18924} // namespace Halide
18925
18926#endif
18927#ifndef HALIDE_CODEGEN_C_H
18928#define HALIDE_CODEGEN_C_H
18929
18930/** \file
18931 *
18932 * Defines an IRPrinter that emits C++ code equivalent to a halide stmt
18933 */
18934
18935#ifndef HALIDE_IR_PRINTER_H
18936#define HALIDE_IR_PRINTER_H
18937
18938/** \file
18939 * This header file defines operators that let you dump a Halide
18940 * expression, statement, or type directly into an output stream
18941 * in a human readable form.
18942 * E.g:
18943 \code
18944 Expr foo = ...
18945 std::cout << "Foo is " << foo << "\n";
18946 \endcode
18947 *
18948 * These operators are implemented using \ref Halide::Internal::IRPrinter
18949 */
18950
18951#include <ostream>
18952
18953
18954namespace Halide {
18955
18956/** Emit an expression on an output stream (such as std::cout) in
18957 * human-readable form */
18958std::ostream &operator<<(std::ostream &stream, const Expr &);
18959
18960/** Emit a halide type on an output stream (such as std::cout) in
18961 * human-readable form */
18962std::ostream &operator<<(std::ostream &stream, const Type &);
18963
18964/** Emit a halide Module on an output stream (such as std::cout) in
18965 * human-readable form */
18966std::ostream &operator<<(std::ostream &stream, const Module &);
18967
18968/** Emit a halide device api type in human-readable form */
18969std::ostream &operator<<(std::ostream &stream, const DeviceAPI &);
18970
18971/** Emit a halide memory type in human-readable form */
18972std::ostream &operator<<(std::ostream &stream, const MemoryType &);
18973
18974/** Emit a halide tail strategy in human-readable form */
18975std::ostream &operator<<(std::ostream &stream, const TailStrategy &t);
18976
18977/** Emit a halide LoopLevel in human-readable form */
18978std::ostream &operator<<(std::ostream &stream, const LoopLevel &);
18979
18980struct Target;
18981/** Emit a halide Target in a human readable form */
18982std::ostream &operator<<(std::ostream &stream, const Target &);
18983
18984namespace Internal {
18985
18986struct AssociativePattern;
18987struct AssociativeOp;
18988
18989/** Emit a halide associative pattern on an output stream (such as std::cout)
18990 * in a human-readable form */
18991std::ostream &operator<<(std::ostream &stream, const AssociativePattern &);
18992
18993/** Emit a halide associative op on an output stream (such as std::cout) in a
18994 * human-readable form */
18995std::ostream &operator<<(std::ostream &stream, const AssociativeOp &);
18996
18997/** Emit a halide statement on an output stream (such as std::cout) in
18998 * a human-readable form */
18999std::ostream &operator<<(std::ostream &stream, const Stmt &);
19000
19001/** Emit a halide for loop type (vectorized, serial, etc) in a human
19002 * readable form */
19003std::ostream &operator<<(std::ostream &stream, const ForType &);
19004
19005/** Emit a horizontal vector reduction op in human-readable form. */
19006std::ostream &operator<<(std::ostream &stream, const VectorReduce::Operator &);
19007
19008/** Emit a halide name mangling value in a human readable format */
19009std::ostream &operator<<(std::ostream &stream, const NameMangling &);
19010
19011/** Emit a halide LoweredFunc in a human readable format */
19012std::ostream &operator<<(std::ostream &stream, const LoweredFunc &);
19013
19014/** Emit a halide linkage value in a human readable format */
19015std::ostream &operator<<(std::ostream &stream, const LinkageType &);
19016
19017/** Emit a halide dimension type in human-readable format */
19018std::ostream &operator<<(std::ostream &stream, const DimType &);
19019
/** A helper type used to stream the current indentation level onto an
 * output stream via the operator<< below (see IRPrinter::get_indent). */
struct Indentation {
    int indent;
};
std::ostream &operator<<(std::ostream &stream, const Indentation &);
19024
19025/** An IRVisitor that emits IR to the given output stream in a human
19026 * readable form. Can be subclassed if you want to modify the way in
19027 * which it prints.
19028 */
class IRPrinter : public IRVisitor {
public:
    /** Construct an IRPrinter pointed at a given output stream
     * (e.g. std::cout, or a std::ofstream) */
    explicit IRPrinter(std::ostream &);

    /** Emit an expression on the output stream */
    void print(const Expr &);

    /** Emit an expression on the output stream without enclosing parens */
    void print_no_parens(const Expr &);

    /** Emit a statement on the output stream */
    void print(const Stmt &);

    /** Emit a comma-delimited list of exprs, without any leading or
     * trailing punctuation. */
    void print_list(const std::vector<Expr> &exprs);

    /** Internal self-test routine. */
    static void test();

protected:
    /** The current indentation level wrapped as a printable object. */
    Indentation get_indent() const {
        return Indentation{indent};
    }

    /** The stream on which we're outputting */
    std::ostream &stream;

    /** The current indentation level, useful for pretty-printing
     * statements */
    int indent = 0;

    /** Certain expressions do not need parens around them, e.g. the
     * args to a call are already separated by commas and a
     * surrounding set of parens. */
    bool implicit_parens = false;

    /** Either emits "(" or "", depending on the value of implicit_parens */
    void open();

    /** Either emits ")" or "", depending on the value of implicit_parens */
    void close();

    /** The symbols whose types can be inferred from values printed
     * already. */
    Scope<> known_type;

    /** A helper for printing a chain of lets with line breaks */
    void print_lets(const Let *let);

    /** IRVisitor overrides; each prints the corresponding IR node. */
    // @{
    void visit(const IntImm *) override;
    void visit(const UIntImm *) override;
    void visit(const FloatImm *) override;
    void visit(const StringImm *) override;
    void visit(const Cast *) override;
    void visit(const Variable *) override;
    void visit(const Add *) override;
    void visit(const Sub *) override;
    void visit(const Mul *) override;
    void visit(const Div *) override;
    void visit(const Mod *) override;
    void visit(const Min *) override;
    void visit(const Max *) override;
    void visit(const EQ *) override;
    void visit(const NE *) override;
    void visit(const LT *) override;
    void visit(const LE *) override;
    void visit(const GT *) override;
    void visit(const GE *) override;
    void visit(const And *) override;
    void visit(const Or *) override;
    void visit(const Not *) override;
    void visit(const Select *) override;
    void visit(const Load *) override;
    void visit(const Ramp *) override;
    void visit(const Broadcast *) override;
    void visit(const Call *) override;
    void visit(const Let *) override;
    void visit(const LetStmt *) override;
    void visit(const AssertStmt *) override;
    void visit(const ProducerConsumer *) override;
    void visit(const For *) override;
    void visit(const Acquire *) override;
    void visit(const Store *) override;
    void visit(const Provide *) override;
    void visit(const Allocate *) override;
    void visit(const Free *) override;
    void visit(const Realize *) override;
    void visit(const Block *) override;
    void visit(const Fork *) override;
    void visit(const IfThenElse *) override;
    void visit(const Evaluate *) override;
    void visit(const Shuffle *) override;
    void visit(const VectorReduce *) override;
    void visit(const Prefetch *) override;
    void visit(const Atomic *) override;
    // @}
};
19127
19128} // namespace Internal
19129} // namespace Halide
19130
19131#endif
19132
19133namespace Halide {
19134
19135struct Argument;
19136class Module;
19137
19138namespace Internal {
19139
19140struct LoweredFunc;
19141
19142/** This class emits C++ code equivalent to a halide Stmt. It's
19143 * mostly the same as an IRPrinter, but it's wrapped in a function
19144 * definition, and some things are handled differently to be valid
19145 * C++.
19146 */
class CodeGen_C : public IRPrinter {
public:
    /** What kind of output to produce: a header or an implementation,
     * a bare extern declaration, and C vs C++ linkage. */
    enum OutputKind {
        CHeader,
        CPlusPlusHeader,
        CImplementation,
        CPlusPlusImplementation,
        CExternDecl,
        CPlusPlusExternDecl,
    };

    /** Initialize a C code generator pointing at a particular output
     * stream (e.g. a file, or std::cout) */
    CodeGen_C(std::ostream &dest,
              const Target &target,
              OutputKind output_kind = CImplementation,
              const std::string &include_guard = "");
    ~CodeGen_C() override;

    /** Emit the declarations contained in the module as C code. */
    void compile(const Module &module);

    /** The target we're generating code for */
    const Target &get_target() const {
        return target;
    }

    /** Internal self-test routine. */
    static void test();

protected:
    enum class IntegerSuffixStyle {
        PlainC = 0,
        OpenCL = 1,
        HLSL = 2
    };

    /** How to emit 64-bit integer constants */
    IntegerSuffixStyle integer_suffix_style = IntegerSuffixStyle::PlainC;

    /** Emit a declaration. */
    // @{
    virtual void compile(const LoweredFunc &func);
    virtual void compile(const Buffer<> &buffer);
    // @}

    /** An ID for the most recently generated ssa variable */
    std::string id;

    /** The target being generated for. */
    Target target;

    /** Controls whether this instance is generating declarations or
     * definitions and whether the interface is extern "C" or C++. */
    OutputKind output_kind;

    /** A cache of generated values in scope */
    std::map<std::string, std::string> cache;

    /** Emit an expression as an assignment, then return the id of the
     * resulting var */
    std::string print_expr(const Expr &);

    /** Like print_expr, but cast the Expr to the given Type */
    std::string print_cast_expr(const Type &, const Expr &);

    /** Emit a statement */
    void print_stmt(const Stmt &);

    /** Emit a runtime assertion that a condition holds, failing with the
     * given message. The first overload takes the id of a
     * previously-printed condition expression. */
    // @{
    void create_assertion(const std::string &id_cond, const Expr &message);
    void create_assertion(const Expr &cond, const Expr &message);
    // @}

    /** Rewrite a horizontal vector reduction into equivalent scalar Exprs. */
    Expr scalarize_vector_reduce(const VectorReduce *op);
    enum AppendSpaceIfNeeded {
        DoNotAppendSpace,
        AppendSpace,
    };

    /** Emit the C name for a halide type. If space_option is AppendSpace,
     * and there should be a space between the type and the next token,
     * one is appended. (This allows both "int foo" and "Foo *foo" to be
     * formatted correctly. Otherwise the latter is "Foo * foo".)
     */
    virtual std::string print_type(Type, AppendSpaceIfNeeded space_option = DoNotAppendSpace);

    /** Emit a statement to reinterpret an expression as another type */
    virtual std::string print_reinterpret(Type, const Expr &);

    /** Emit a version of a string that is a valid identifier in C (. is replaced with _) */
    virtual std::string print_name(const std::string &);

    /** Add typedefs for vector types. Not needed for OpenCL, might
     * use different syntax for other C-like languages. */
    virtual void add_vector_typedefs(const std::set<Type> &vector_types);

    /** Bottleneck to allow customization of calls to generic Extern/PureExtern calls. */
    virtual std::string print_extern_call(const Call *op);

    /** Convert a vector Expr into a series of scalar Exprs, then reassemble into vector of original type. */
    std::string print_scalarized_expr(const Expr &e);

    /** Emit an SSA-style assignment, and set id to the freshly generated name. Return id. */
    virtual std::string print_assignment(Type t, const std::string &rhs);

    /** Emit free for the heap allocation. */
    void print_heap_free(const std::string &alloc_name);

    /** Return true if only generating a C or C++ header (declarations,
     * no definitions). */
    bool is_header() {
        return output_kind == CHeader ||
               output_kind == CPlusPlusHeader;
    }

    /** Return true if only generating a bare extern declaration, with
     * either "C" or C++ linkage. */
    bool is_extern_decl() {
        return output_kind == CExternDecl ||
               output_kind == CPlusPlusExternDecl;
    }

    /** Return true if only generating an interface (a header or an
     * extern declaration), which may be extern "C" or C++ */
    bool is_header_or_extern_decl() {
        return is_header() || is_extern_decl();
    }

    /** Return true if generating C++ linkage. */
    bool is_c_plus_plus_interface() {
        return output_kind == CPlusPlusHeader ||
               output_kind == CPlusPlusImplementation ||
               output_kind == CPlusPlusExternDecl;
    }

    /** Open a new C scope (i.e. throw in a brace, increase the indent) */
    void open_scope();

    /** Close a C scope (i.e. throw in an end brace, decrease the indent) */
    void close_scope(const std::string &comment);

    /** Metadata tracked per live allocation. */
    struct Allocation {
        Type type;
    };

    /** Track the types of allocations to avoid unnecessary casts. */
    Scope<Allocation> allocations;

    /** Track which allocations actually went on the heap. */
    Scope<> heap_allocations;

    /** True if there is a void * __user_context parameter in the arguments. */
    bool have_user_context;

    /** Track current calling convention scope. */
    bool extern_c_open;

    /** True if at least one gpu-based for loop is used. */
    bool uses_gpu_for_loops;

    /** Track which handle types have been forward-declared already. */
    std::set<const halide_handle_cplusplus_type *> forward_declared;

    /** If the Type is a handle type, emit a forward-declaration for it
     * if we haven't already. */
    void forward_declare_type_if_needed(const Type &t);

    /** Set the name mangling mode for subsequently emitted declarations. */
    void set_name_mangling_mode(NameMangling mode);

    using IRPrinter::visit;

    /** IRPrinter overrides; each emits C syntax for the corresponding
     * IR node. */
    // @{
    void visit(const Variable *) override;
    void visit(const IntImm *) override;
    void visit(const UIntImm *) override;
    void visit(const StringImm *) override;
    void visit(const FloatImm *) override;
    void visit(const Cast *) override;
    void visit(const Add *) override;
    void visit(const Sub *) override;
    void visit(const Mul *) override;
    void visit(const Div *) override;
    void visit(const Mod *) override;
    void visit(const Max *) override;
    void visit(const Min *) override;
    void visit(const EQ *) override;
    void visit(const NE *) override;
    void visit(const LT *) override;
    void visit(const LE *) override;
    void visit(const GT *) override;
    void visit(const GE *) override;
    void visit(const And *) override;
    void visit(const Or *) override;
    void visit(const Not *) override;
    void visit(const Call *) override;
    void visit(const Select *) override;
    void visit(const Load *) override;
    void visit(const Store *) override;
    void visit(const Let *) override;
    void visit(const LetStmt *) override;
    void visit(const AssertStmt *) override;
    void visit(const ProducerConsumer *) override;
    void visit(const For *) override;
    void visit(const Ramp *) override;
    void visit(const Broadcast *) override;
    void visit(const Provide *) override;
    void visit(const Allocate *) override;
    void visit(const Free *) override;
    void visit(const Realize *) override;
    void visit(const IfThenElse *) override;
    void visit(const Evaluate *) override;
    void visit(const Shuffle *) override;
    void visit(const Prefetch *) override;
    void visit(const Fork *) override;
    void visit(const Acquire *) override;
    void visit(const Atomic *) override;
    void visit(const VectorReduce *) override;
    // @}

    /** Emit a binary infix operation on a and b at type t, using the
     * given C operator. */
    void visit_binop(Type t, const Expr &a, const Expr &b, const char *op);
    /** Like visit_binop, but with distinct operators for the scalar
     * and vector cases. */
    void visit_relop(Type t, const Expr &a, const Expr &b, const char *scalar_op, const char *vector_op);

    /** Join the elements of v into a single string, separated by sep. */
    template<typename T>
    static std::string with_sep(const std::vector<T> &v, const std::string &sep) {
        std::ostringstream o;
        for (size_t i = 0; i < v.size(); ++i) {
            if (i > 0) {
                o << sep;
            }
            o << v[i];
        }
        return o.str();
    }

    /** Join the elements of v into a single comma-separated string. */
    template<typename T>
    static std::string with_commas(const std::vector<T> &v) {
        return with_sep<T>(v, ", ");
    }

    /** Are we inside an atomic node that uses mutex locks?
        This is used for detecting deadlocks from nested atomics. */
    bool inside_atomic_mutex_node;

    /** Emit atomic store instructions? */
    bool emit_atomic_stores;

    /** true if add_vector_typedefs() has been called. */
    bool using_vector_typedefs;
};
19389
19390} // namespace Internal
19391} // namespace Halide
19392
19393#endif
19394#ifndef HALIDE_CODEGEN_D3D12_COMPUTE_DEV_H
19395#define HALIDE_CODEGEN_D3D12_COMPUTE_DEV_H
19396
19397/** \file
19398 * Defines the code-generator for producing D3D12-compatible HLSL kernel code
19399 */
19400
19401#include <memory>
19402
19403namespace Halide {
19404
19405struct Target;
19406
19407namespace Internal {
19408
19409struct CodeGen_GPU_Dev;
19410
19411std::unique_ptr<CodeGen_GPU_Dev> new_CodeGen_D3D12Compute_Dev(const Target &target);
19412
19413} // namespace Internal
19414} // namespace Halide
19415
19416#endif
19417#ifndef HALIDE_CODEGEN_GPU_DEV_H
19418#define HALIDE_CODEGEN_GPU_DEV_H
19419
19420/** \file
19421 * Defines the code-generator interface for producing GPU device code
19422 */
19423#include <string>
19424#include <vector>
19425
19426#ifndef HALIDE_DEVICE_ARGUMENT_H
19427#define HALIDE_DEVICE_ARGUMENT_H
19428
19429/** \file
19430 * Defines helpers for passing arguments to separate devices, such as GPUs.
19431 */
19432#include <string>
19433
19434
19435namespace Halide {
19436namespace Internal {
19437
19438/** A DeviceArgument looks similar to an Halide::Argument, but has behavioral
19439 * differences that make it specific to the GPU pipeline; the fact that
19440 * neither is-a nor has-a Halide::Argument is deliberate. In particular, note
19441 * that a Halide::Argument that is a buffer can be read or write, but not both,
19442 * while a DeviceArgument that is a buffer can be read *and* write for some GPU
19443 * backends. */
struct DeviceArgument {
    /** The name of the argument */
    std::string name;

    /** An argument is either a primitive type (for parameters), or a
     * buffer pointer.
     *
     * If is_buffer == false, then type fully encodes the expected type
     * of the scalar argument.
     *
     * If is_buffer == true, then type.bytes() should be used to determine
     * elem_size of the buffer; additionally, type.code *should* reflect
     * the expected interpretation of the buffer data (e.g. float vs int),
     * but there is no runtime enforcement of this at present.
     */
    bool is_buffer = false;

    /** If is_buffer == true and memory_type == GPUTexture, this argument should be
     * passed and accessed through texture sampler operations instead of
     * directly as a memory array
     */
    MemoryType memory_type = MemoryType::Auto;

    /** If is_buffer is true, this is the dimensionality of the buffer.
     * If is_buffer is false, this value is ignored (and should always be set to zero) */
    uint8_t dimensions = 0;

    /** If this is a scalar parameter, then this is its type.
     *
     * If this is a buffer parameter, this is used to determine elem_size
     * of the halide_buffer_t.
     *
     * Note that type.lanes() should always be 1 here. */
    Type type;

    /** The static size of the argument if known, or zero otherwise. */
    size_t size = 0;

    /** The index of the first element of the argument when packed into a wider
     * type, such as packing scalar floats into vec4 for GLSL. */
    size_t packed_index = 0;

    /** For buffers, these two variables can be used to specify whether the
     * buffer is read or written. By default, we assume that the argument
     * buffer is read-write and set both flags. */
    bool read = false;
    bool write = false;

    /** Alignment information for integer parameters. */
    ModulusRemainder alignment;

    DeviceArgument() = default;

    /** Construct a fully-specified argument. Buffer arguments are
     * conservatively marked as both read and written; scalar arguments
     * as neither. */
    DeviceArgument(const std::string &_name,
                   bool _is_buffer,
                   MemoryType _mem,
                   Type _type,
                   uint8_t _dimensions,
                   size_t _size = 0)
        : name(_name),
          is_buffer(_is_buffer),
          memory_type(_mem),
          dimensions(_dimensions),
          type(_type),
          size(_size),

          read(_is_buffer),
          write(_is_buffer) {
    }
};
19514
19515/** A Closure modified to inspect GPU-specific memory accesses, and
19516 * produce a vector of DeviceArgument objects. */
class HostClosure : public Closure {
public:
    /** Traverse a statement and gather the arguments that must be
     * passed to the corresponding device kernel. */
    HostClosure(const Stmt &s, const std::string &loop_variable = "");

    /** Get a description of the captured arguments. */
    std::vector<DeviceArgument> arguments();

protected:
    using Internal::Closure::visit;
    /** Overrides that additionally recognize GPU-specific constructs
     * (for loops and calls). */
    // @{
    void visit(const For *loop) override;
    void visit(const Call *op) override;
    // @}
};
19529
19530} // namespace Internal
19531} // namespace Halide
19532
19533#endif
19534
19535namespace Halide {
19536namespace Internal {
19537
19538/** A code generator that emits GPU code from a given Halide stmt. */
struct CodeGen_GPU_Dev {
    virtual ~CodeGen_GPU_Dev();

    /** Compile a GPU kernel into the module. This may be called many times
     * with different kernels, which will all be accumulated into a single
     * source module shared by a given Halide pipeline. */
    virtual void add_kernel(Stmt stmt,
                            const std::string &name,
                            const std::vector<DeviceArgument> &args) = 0;

    /** (Re)initialize the GPU kernel module. This is separate from compile,
     * since a GPU device module will often have many kernels compiled into it
     * for a single pipeline. */
    virtual void init_module() = 0;

    /** Compile the accumulated kernel module to a blob of API-specific
     * source (or binary) bytes. */
    virtual std::vector<char> compile_to_src() = 0;

    /** The name of the current kernel. */
    virtual std::string get_current_kernel_name() = 0;

    /** Dump internal state, for debugging. */
    virtual void dump() = 0;

    /** This routine returns the GPU API name that is combined into
     * runtime routine names to ensure each GPU API has a unique
     * name.
     */
    virtual std::string api_unique_name() = 0;

    /** Returns the specified name transformed by the variable naming rules
     * for the GPU language backend. Used to determine the name of a parameter
     * during host codegen. */
    virtual std::string print_gpu_name(const std::string &name) = 0;

    /** Allows the GPU device specific code to request halide_type_t
     * values to be passed to the kernel_run routine rather than just
     * argument type sizes.
     */
    virtual bool kernel_run_takes_types() const {
        return false;
    }

    /** Query whether a variable name refers to a GPU loop index
     * (any, block-level, or thread-level respectively). */
    // @{
    static bool is_gpu_var(const std::string &name);
    static bool is_gpu_block_var(const std::string &name);
    static bool is_gpu_thread_var(const std::string &name);
    // @}

    /** Checks if expr is block uniform, i.e. does not depend on a thread
     * var. */
    static bool is_block_uniform(const Expr &expr);
    /** Checks if the buffer is a candidate for constant storage. Most
     * GPUs (APIs) support a constant memory storage class that cannot be
     * written to and performs well for block uniform accesses. A buffer is a
     * candidate for constant storage if it is never written to, and loads are
     * uniform within the workgroup. */
    static bool is_buffer_constant(const Stmt &kernel, const std::string &buffer);

    /** A mask describing which type of memory fence to use for the gpu_thread_barrier()
     * intrinsic. Not all GPU APIs support all types.
     */
    enum MemoryFenceType {
        None = 0,    // No fence required (just a sync)
        Device = 1,  // Device/global memory fence
        Shared = 2   // Threadgroup/shared memory fence
    };
};
19602
19603} // namespace Internal
19604} // namespace Halide
19605
19606#endif
19607#ifndef HALIDE_CODEGEN_INTERNAL_H
19608#define HALIDE_CODEGEN_INTERNAL_H
19609
19610/** \file
19611 *
19612 * Defines functionality that's useful to multiple target-specific
19613 * CodeGen paths, but shouldn't live in CodeGen_LLVM.h (because that's the
19614 * front-end-facing interface to CodeGen).
19615 */
19616
19617#include <memory>
19618#include <string>
19619
19620
19621namespace llvm {
19622class ConstantFolder;
19623class ElementCount;
19624class Function;
19625class IRBuilderDefaultInserter;
19626class LLVMContext;
19627class Module;
19628class StructType;
19629class TargetMachine;
19630class TargetOptions;
19631class Type;
19632class Value;
19633template<typename, typename>
19634class IRBuilder;
19635} // namespace llvm
19636
19637namespace Halide {
19638
19639struct Target;
19640
19641namespace Internal {
19642
19643/** The llvm type of a struct containing all of the externally referenced state of a Closure. */
19644llvm::StructType *build_closure_type(const Closure &closure, llvm::StructType *halide_buffer_t_type, llvm::LLVMContext *context);
19645
19646/** Emit code that builds a struct containing all the externally
19647 * referenced state. Requires you to pass it a type and struct to fill in,
19648 * a scope to retrieve the llvm values from and a builder to place
19649 * the packing code. */
19650void pack_closure(llvm::StructType *type,
19651 llvm::Value *dst,
19652 const Closure &closure,
19653 const Scope<llvm::Value *> &src,
19654 llvm::StructType *halide_buffer_t_type,
19655 llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter> *builder);
19656
19657/** Emit code that unpacks a struct containing all the externally
19658 * referenced state into a symbol table. Requires you to pass it a
19659 * state struct type and value, a scope to fill, and a builder to place the
19660 * unpacking code. */
19661void unpack_closure(const Closure &closure,
19662 Scope<llvm::Value *> &dst,
19663 llvm::StructType *type,
19664 llvm::Value *src,
19665 llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter> *builder);
19666
19667/** Get the llvm type equivalent to a given halide type */
19668llvm::Type *llvm_type_of(llvm::LLVMContext *context, Halide::Type t);
19669
19670/** Get the number of elements in an llvm vector type, or return 1 if
19671 * it's not a vector type. */
19672int get_vector_num_elements(llvm::Type *);
19673
19674/** Get the scalar type of an llvm vector type. Returns the argument
19675 * if it's not a vector type. */
19676llvm::Type *get_vector_element_type(llvm::Type *);
19677
19678llvm::ElementCount element_count(int e);
19679
19680llvm::Type *get_vector_type(llvm::Type *, int);
19681
19682/** Which built-in functions require a user-context first argument? */
19683bool function_takes_user_context(const std::string &name);
19684
/** Given a size (in bytes), return true if the allocation size can fit
 * on the stack; otherwise, return false. This routine asserts if size is
 * non-positive. */
19688bool can_allocation_fit_on_stack(int64_t size);
19689
19690/** Does a {div/mod}_round_to_zero using binary long division for int/uint.
19691 * max_abs is the maximum absolute value of (a/b).
19692 * Returns the pair {div_round_to_zero, mod_round_to_zero}. */
19693std::pair<Expr, Expr> long_div_mod_round_to_zero(const Expr &a, const Expr &b,
19694 const uint64_t *max_abs = nullptr);
19695
19696/** Given a Halide Euclidean division/mod operation, do constant optimizations
19697 * and possibly call lower_euclidean_div/lower_euclidean_mod if necessary.
19698 * Can introduce mulhi_shr and sorted_avg intrinsics as well as those from the
19699 * lower_euclidean_ operation -- div_round_to_zero or mod_round_to_zero. */
19700///@{
19701Expr lower_int_uint_div(const Expr &a, const Expr &b);
19702Expr lower_int_uint_mod(const Expr &a, const Expr &b);
19703///@}
19704
19705/** Given a Halide Euclidean division/mod operation, define it in terms of
19706 * div_round_to_zero or mod_round_to_zero. */
19707///@{
19708Expr lower_euclidean_div(Expr a, Expr b);
19709Expr lower_euclidean_mod(Expr a, Expr b);
19710///@}
19711
19712/** Given a Halide shift operation with a signed shift amount (may be negative), define
19713 * an equivalent expression using only shifts by unsigned amounts. */
19714///@{
19715Expr lower_signed_shift_left(const Expr &a, const Expr &b);
19716Expr lower_signed_shift_right(const Expr &a, const Expr &b);
19717///@}
19718
19719/** Reduce a mux intrinsic to a select tree */
19720Expr lower_mux(const Call *mux);
19721
19722/** Given an llvm::Module, set llvm:TargetOptions, cpu and attr information */
19723void get_target_options(const llvm::Module &module, llvm::TargetOptions &options, std::string &mcpu, std::string &mattrs);
19724
19725/** Given two llvm::Modules, clone target options from one to the other */
19726void clone_target_options(const llvm::Module &from, llvm::Module &to);
19727
19728/** Given an llvm::Module, get or create an llvm:TargetMachine */
19729std::unique_ptr<llvm::TargetMachine> make_target_machine(const llvm::Module &module);
19730
19731/** Set the appropriate llvm Function attributes given a Target. */
19732void set_function_attributes_for_target(llvm::Function *, const Target &);
19733
19734/** Save a copy of the llvm IR currently represented by the module as
19735 * data in the __LLVM,__bitcode section. Emulates clang's
19736 * -fembed-bitcode flag and is useful to satisfy Apple's bitcode
19737 * inclusion requirements. */
19738void embed_bitcode(llvm::Module *M, const std::string &halide_command);
19739
19740} // namespace Internal
19741} // namespace Halide
19742
19743#endif
19744#ifndef HALIDE_CODEGEN_LLVM_H
19745#define HALIDE_CODEGEN_LLVM_H
19746
19747/** \file
19748 *
19749 * Defines the base-class for all architecture-specific code
19750 * generators that use llvm.
19751 */
19752
19753namespace llvm {
19754class Value;
19755class Module;
19756class Function;
19757class FunctionType;
19758class IRBuilderDefaultInserter;
19759class ConstantFolder;
19760template<typename, typename>
19761class IRBuilder;
19762class LLVMContext;
19763class Type;
19764class StructType;
19765class Instruction;
19766class CallInst;
19767class ExecutionEngine;
19768class AllocaInst;
19769class Constant;
19770class Triple;
19771class MDNode;
19772class NamedMDNode;
19773class DataLayout;
19774class BasicBlock;
19775class GlobalVariable;
19776} // namespace llvm
19777
19778#include <map>
19779#include <memory>
19780#include <string>
19781#include <vector>
19782
19783
19784namespace Halide {
19785
19786struct ExternSignature;
19787
19788namespace Internal {
19789
19790/** A code generator abstract base class. Actual code generators
19791 * (e.g. CodeGen_X86) inherit from this. This class is responsible
19792 * for taking a Halide Stmt and producing llvm bitcode, machine
19793 * code in an object file, or machine code accessible through a
19794 * function pointer.
19795 */
class CodeGen_LLVM : public IRVisitor {
public:
    /** Create an instance of CodeGen_LLVM suitable for the target. */
    static std::unique_ptr<CodeGen_LLVM> new_for_target(const Target &target, llvm::LLVMContext &context);

    ~CodeGen_LLVM() override;

    /** Takes a halide Module and compiles it to an llvm Module. */
    virtual std::unique_ptr<llvm::Module> compile(const Module &module);

    /** The target we're generating code for */
    const Target &get_target() const {
        return target;
    }

    /** Tell the code generator which LLVM context to use. */
    void set_context(llvm::LLVMContext &context);

    /** Initialize internal llvm state for the enabled targets. */
    static void initialize_llvm();

    /** Build an llvm Module containing trampoline wrappers for the
     * given (name, signature) extern function pairs. 'suffix' is a
     * string used to disambiguate the generated module. */
    static std::unique_ptr<llvm::Module> compile_trampolines(
        const Target &target,
        llvm::LLVMContext &context,
        const std::string &suffix,
        const std::vector<std::pair<std::string, ExternSignature>> &externs);

    /** A (very) conservative total of alloca storage requested during
     * codegen; see the comment on requested_alloca_total below for
     * its intended use and limitations. */
    size_t get_requested_alloca_total() const {
        return requested_alloca_total;
    }

protected:
    CodeGen_LLVM(const Target &t);

    /** Compile a specific halide declaration into the llvm Module. */
    // @{
    virtual void compile_func(const LoweredFunc &func, const std::string &simple_name, const std::string &extern_name);
    virtual void compile_buffer(const Buffer<> &buffer);
    // @}

    /** Helper functions for compiling Halide functions to llvm
     * functions. begin_func performs all the work necessary to begin
     * generating code for a function with a given argument list with
     * the IRBuilder. A call to begin_func should be a followed by a
     * call to end_func with the same arguments, to generate the
     * appropriate cleanup code. */
    // @{
    virtual void begin_func(LinkageType linkage, const std::string &simple_name,
                            const std::string &extern_name, const std::vector<LoweredArgument> &args);
    virtual void end_func(const std::vector<LoweredArgument> &args);
    // @}

    /** What should be passed as -mcpu, -mattrs, and related for
     * compilation. The architecture-specific code generator should
     * define these. */
    // @{
    virtual std::string mcpu() const = 0;
    virtual std::string mattrs() const = 0;
    virtual std::string mabi() const;
    virtual bool use_soft_float_abi() const = 0;
    virtual bool use_pic() const;
    // @}

    /** Should indexing math be promoted to 64-bit on platforms with
     * 64-bit pointers? */
    virtual bool promote_indices() const {
        return true;
    }

    /** What's the natural vector bit-width to use for loads, stores, etc. */
    virtual int native_vector_bits() const = 0;

    /** Return the type in which arithmetic should be done for the
     * given storage type. */
    virtual Type upgrade_type_for_arithmetic(const Type &) const;

    /** Return the type that a given Halide type should be
     * stored/loaded from memory as. */
    virtual Type upgrade_type_for_storage(const Type &) const;

    /** Return the type that a Halide type should be passed in and out
     * of functions as. */
    virtual Type upgrade_type_for_argument_passing(const Type &) const;

    /** Codegen state: the llvm module being built, the function and
     * value most recently generated, the context/builder in use, and
     * related metadata nodes. */
    // @{
    std::unique_ptr<llvm::Module> module;
    llvm::Function *function;
    llvm::LLVMContext *context;
    llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter> *builder;
    llvm::Value *value;
    llvm::MDNode *very_likely_branch;
    llvm::MDNode *default_fp_math_md;
    llvm::MDNode *strict_fp_math_md;
    std::vector<LoweredArgument> current_function_args;
    //@}

    /** The target we're generating code for */
    Halide::Target target;

    /** Grab all the context specific internal state. */
    virtual void init_context();
    /** Initialize the CodeGen_LLVM internal state to compile a fresh
     * module. This allows reuse of one CodeGen_LLVM object to compile
     * multiple related modules (e.g. multiple device kernels). */
    virtual void init_module();

    /** Add external_code entries to llvm module. */
    void add_external_code(const Module &halide_module);

    /** Run all of llvm's optimization passes on the module. */
    void optimize_module();

    /** Add an entry to the symbol table, hiding previous entries with
     * the same name. Call this when new values come into scope. */
    void sym_push(const std::string &name, llvm::Value *value);

    /** Remove an entry for the symbol table, revealing any previous
     * entries with the same name. Call this when values go out of
     * scope. */
    void sym_pop(const std::string &name);

    /** Fetch an entry from the symbol table. If the symbol is not
     * found, it either errors out (if the second arg is true), or
     * returns nullptr. */
    llvm::Value *sym_get(const std::string &name,
                         bool must_succeed = true) const;

    /** Test if an item exists in the symbol table. */
    bool sym_exists(const std::string &name) const;

    /** Given a Halide ExternSignature, return the equivalent llvm::FunctionType. */
    llvm::FunctionType *signature_to_type(const ExternSignature &signature);

    /** Some useful llvm types */
    // @{
    llvm::Type *void_t, *i1_t, *i8_t, *i16_t, *i32_t, *i64_t, *f16_t, *f32_t, *f64_t;
    llvm::StructType *halide_buffer_t_type,
        *type_t_type,
        *dimension_t_type,
        *metadata_t_type,
        *argument_t_type,
        *scalar_value_t_type,
        *device_interface_t_type,
        *pseudostack_slot_t_type,
        *semaphore_t_type,
        *semaphore_acquire_t_type,
        *parallel_task_t_type;

    // @}

    /** Some wildcard variables used for peephole optimizations in
     * subclasses */
    // @{
    Expr wild_u1x_, wild_i8x_, wild_u8x_, wild_i16x_, wild_u16x_;
    Expr wild_i32x_, wild_u32x_, wild_i64x_, wild_u64x_;
    Expr wild_f32x_, wild_f64x_;

    // Wildcards for scalars.
    Expr wild_u1_, wild_i8_, wild_u8_, wild_i16_, wild_u16_;
    Expr wild_i32_, wild_u32_, wild_i64_, wild_u64_;
    Expr wild_f32_, wild_f64_;
    // @}

    /** Emit code that evaluates an expression, and return the llvm
     * representation of the result of the expression. */
    llvm::Value *codegen(const Expr &);

    /** Emit code that runs a statement. */
    void codegen(const Stmt &);

    /** Codegen a vector Expr by codegenning each lane and combining. */
    void scalarize(const Expr &);

    /** Some destructors should always be called. Others should only
     * be called if the pipeline is exiting with an error code. */
    enum DestructorType { Always,
                          OnError,
                          OnSuccess };

    /** Call this at the location of object creation to register how an
     * object should be destroyed. This does three things:
     * 1) Emits code here that puts the object in a unique
     * null-initialized stack slot
     * 2) Adds an instruction to the destructor block that calls the
     * destructor on that stack slot if it's not null.
     * 3) Returns that stack slot, so you can neuter the destructor
     * (by storing null to the stack slot) or destroy the object early
     * (by calling trigger_destructor).
     */
    llvm::Value *register_destructor(llvm::Function *destructor_fn, llvm::Value *obj, DestructorType when);

    /** Call a destructor early. Pass in the value returned by register destructor. */
    void trigger_destructor(llvm::Function *destructor_fn, llvm::Value *stack_slot);

    /** Retrieves the block containing the error handling
     * code. Creates it if it doesn't already exist for this
     * function. */
    llvm::BasicBlock *get_destructor_block();

    /** Codegen an assertion. If false, returns the error code (if not
     * null), or evaluates and returns the message, which must be an
     * Int(32) expression. */
    // @{
    void create_assertion(llvm::Value *condition, const Expr &message, llvm::Value *error_code = nullptr);
    // @}

    /** Codegen a block of asserts with pure conditions */
    void codegen_asserts(const std::vector<const AssertStmt *> &asserts);

    /** Description of one parallel task: its body, the semaphores it
     * must acquire, and its loop variable/bounds. Used when
     * codegenning a call to do_parallel_tasks. */
    struct ParallelTask {
        Stmt body;
        struct SemAcquire {
            Expr semaphore;
            Expr count;
        };
        std::vector<SemAcquire> semaphores;
        std::string loop_var;
        Expr min, extent;
        Expr serial;
        std::string name;
    };
    // NOTE(review): appears to track the nesting depth of parallel
    // tasks during codegen — confirm against the implementation.
    int task_depth;
    void get_parallel_tasks(const Stmt &s, std::vector<ParallelTask> &tasks, std::pair<std::string, int> prefix);
    void do_parallel_tasks(const std::vector<ParallelTask> &tasks);
    void do_as_parallel_task(const Stmt &s);

    /** Return from the pipeline with the given error code. Will run
     * the destructor block. */
    void return_with_error_code(llvm::Value *error_code);

    /** Put a string constant in the module as a global variable and return a pointer to it. */
    llvm::Constant *create_string_constant(const std::string &str);

    /** Put a binary blob in the module as a global variable and return a pointer to it. */
    llvm::Constant *create_binary_blob(const std::vector<char> &data, const std::string &name, bool constant = true);

    /** Widen an llvm scalar into an llvm vector with the given number of lanes. */
    llvm::Value *create_broadcast(llvm::Value *, int lanes);

    /** Generate a pointer into a named buffer at a given index, of a
     * given type. The index counts according to the scalar type of
     * the type passed in. */
    // @{
    llvm::Value *codegen_buffer_pointer(const std::string &buffer, Type type, llvm::Value *index);
    llvm::Value *codegen_buffer_pointer(const std::string &buffer, Type type, Expr index);
    llvm::Value *codegen_buffer_pointer(llvm::Value *base_address, Type type, Expr index);
    llvm::Value *codegen_buffer_pointer(llvm::Value *base_address, Type type, llvm::Value *index);
    // @}

    /** Turn a Halide Type into an llvm::Value representing a constant halide_type_t */
    llvm::Value *make_halide_type_t(const Type &);

    /** Mark a load or store with type-based-alias-analysis metadata
     * so that llvm knows it can reorder loads and stores across
     * different buffers */
    void add_tbaa_metadata(llvm::Instruction *inst, std::string buffer, const Expr &index);

    /** Get a unique name for the actual block of memory that an
     * allocate node uses. Used so that alias analysis understands
     * when multiple Allocate nodes shared the same memory. */
    virtual std::string get_allocation_name(const std::string &n) {
        return n;
    }

    using IRVisitor::visit;

    /** Generate code for various IR nodes. These can be overridden by
     * architecture-specific code to perform peephole
     * optimizations. The result of each is stored in \ref value */
    // @{
    void visit(const IntImm *) override;
    void visit(const UIntImm *) override;
    void visit(const FloatImm *) override;
    void visit(const StringImm *) override;
    void visit(const Cast *) override;
    void visit(const Variable *) override;
    void visit(const Add *) override;
    void visit(const Sub *) override;
    void visit(const Mul *) override;
    void visit(const Div *) override;
    void visit(const Mod *) override;
    void visit(const Min *) override;
    void visit(const Max *) override;
    void visit(const EQ *) override;
    void visit(const NE *) override;
    void visit(const LT *) override;
    void visit(const LE *) override;
    void visit(const GT *) override;
    void visit(const GE *) override;
    void visit(const And *) override;
    void visit(const Or *) override;
    void visit(const Not *) override;
    void visit(const Select *) override;
    void visit(const Load *) override;
    void visit(const Ramp *) override;
    void visit(const Broadcast *) override;
    void visit(const Call *) override;
    void visit(const Let *) override;
    void visit(const LetStmt *) override;
    void visit(const AssertStmt *) override;
    void visit(const ProducerConsumer *) override;
    void visit(const For *) override;
    void visit(const Acquire *) override;
    void visit(const Store *) override;
    void visit(const Block *) override;
    void visit(const Fork *) override;
    void visit(const IfThenElse *) override;
    void visit(const Evaluate *) override;
    void visit(const Shuffle *) override;
    void visit(const VectorReduce *) override;
    void visit(const Prefetch *) override;
    void visit(const Atomic *) override;
    // @}

    /** Generate code for an allocate node. It has no default
     * implementation - it must be handled in an architecture-specific
     * way. */
    void visit(const Allocate *) override = 0;

    /** Generate code for a free node. It has no default
     * implementation and must be handled in an architecture-specific
     * way. */
    void visit(const Free *) override = 0;

    /** These IR nodes should have been removed during
     * lowering. CodeGen_LLVM will error out if they are present */
    // @{
    void visit(const Provide *) override;
    void visit(const Realize *) override;
    // @}

    /** If we have to bail out of a pipeline midway, this should
     * inject the appropriate target-specific cleanup code. */
    virtual void prepare_for_early_exit() {
    }

    /** Get the llvm type equivalent to the given halide type in the
     * current context. */
    virtual llvm::Type *llvm_type_of(const Type &) const;

    /** Perform an alloca at the function entrypoint. Will be cleaned
     * on function exit. */
    llvm::Value *create_alloca_at_entry(llvm::Type *type, int n,
                                        bool zero_initialize = false,
                                        const std::string &name = "");

    /** A (very) conservative guess at the size of all alloca() storage requested
     * (including alignment padding). It's currently meant only to be used as
     * a very coarse way to ensure there is enough stack space when testing
     * on the WebAssembly backend.
     *
     * It is *not* meant to be a useful proxy for "stack space needed", for a
     * number of reasons:
     * - allocas with non-overlapping lifetimes will share space
     * - on some backends, LLVM may promote register-sized allocas into registers
     * - while this accounts for alloca() calls we know about, it doesn't attempt
     * to account for stack spills, function call overhead, etc.
     */
    size_t requested_alloca_total = 0;

    /** Which buffers came in from the outside world (and so we can't
     * guarantee their alignment) */
    std::set<std::string> external_buffer;

    /** The user_context argument. May be a constant null if the
     * function is being compiled without a user context. */
    llvm::Value *get_user_context() const;

    /** Implementation of the intrinsic call to
     * interleave_vectors. This implementation allows for interleaving
     * an arbitrary number of vectors.*/
    virtual llvm::Value *interleave_vectors(const std::vector<llvm::Value *> &);

    /** Description of an intrinsic function overload. Overloads are resolved
     * using both argument and return types. The scalar types of the arguments
     * and return type must match exactly for an overload resolution to succeed. */
    struct Intrinsic {
        Type result_type;
        std::vector<Type> arg_types;
        llvm::Function *impl;

        Intrinsic(Type result_type, std::vector<Type> arg_types, llvm::Function *impl)
            : result_type(result_type), arg_types(std::move(arg_types)), impl(impl) {
        }
    };
    /** Mapping of intrinsic functions to the various overloads implementing it. */
    std::map<std::string, std::vector<Intrinsic>> intrinsics;

    /** Get an LLVM intrinsic declaration. If it doesn't exist, it will be created. */
    llvm::Function *get_llvm_intrin(const Type &ret_type, const std::string &name, const std::vector<Type> &arg_types, bool scalars_are_vectors = false);
    llvm::Function *get_llvm_intrin(llvm::Type *ret_type, const std::string &name, const std::vector<llvm::Type *> &arg_types);
    /** Declare an intrinsic function that participates in overload resolution. */
    llvm::Function *declare_intrin_overload(const std::string &name, const Type &ret_type, const std::string &impl_name, std::vector<Type> arg_types, bool scalars_are_vectors = false);
    void declare_intrin_overload(const std::string &name, const Type &ret_type, llvm::Function *impl, std::vector<Type> arg_types);
    /** Call an overloaded intrinsic function. Returns nullptr if no suitable overload is found. */
    llvm::Value *call_overloaded_intrin(const Type &result_type, const std::string &name, const std::vector<Expr> &args);

    /** Generate a call to a vector intrinsic or runtime inlined
     * function. The arguments are sliced up into vectors of the width
     * given by 'intrin_lanes', the intrinsic is called on each
     * piece, then the results (if any) are concatenated back together
     * into the original type 't'. For the version that takes an
     * llvm::Type *, the type may be void, so the vector width of the
     * arguments must be specified explicitly as
     * 'called_lanes'. */
    // @{
    llvm::Value *call_intrin(const Type &t, int intrin_lanes,
                             const std::string &name, std::vector<Expr>);
    llvm::Value *call_intrin(const Type &t, int intrin_lanes,
                             llvm::Function *intrin, std::vector<Expr>);
    llvm::Value *call_intrin(llvm::Type *t, int intrin_lanes,
                             const std::string &name, std::vector<llvm::Value *>);
    llvm::Value *call_intrin(llvm::Type *t, int intrin_lanes,
                             llvm::Function *intrin, std::vector<llvm::Value *>);
    // @}

    /** Take a slice of lanes out of an llvm vector. Pads with undefs
     * if you ask for more lanes than the vector has. */
    virtual llvm::Value *slice_vector(llvm::Value *vec, int start, int extent);

    /** Concatenate a bunch of llvm vectors. Must be of the same type. */
    virtual llvm::Value *concat_vectors(const std::vector<llvm::Value *> &);

    /** Create an LLVM shuffle vectors instruction. */
    virtual llvm::Value *shuffle_vectors(llvm::Value *a, llvm::Value *b,
                                         const std::vector<int> &indices);
    /** Shorthand for shuffling a vector with an undef vector. */
    llvm::Value *shuffle_vectors(llvm::Value *v, const std::vector<int> &indices);

    /** Go looking for a vector version of a runtime function. Will
     * return the best match. Matches in the following order:
     *
     * 1) The requested vector width.
     *
     * 2) The width which is the smallest power of two
     * greater than or equal to the vector width.
     *
     * 3) All the factors of 2) greater than one, in decreasing order.
     *
     * 4) The smallest power of two not yet tried.
     *
     * So for a 5-wide vector, it tries: 5, 8, 4, 2, 16.
     *
     * If there's no match, returns (nullptr, 0).
     */
    std::pair<llvm::Function *, int> find_vector_runtime_function(const std::string &name, int lanes);

    virtual bool supports_atomic_add(const Type &t) const;

    /** Compile a horizontal reduction that starts with an explicit
     * initial value. There are lots of complex ways to peephole
     * optimize this pattern, especially with the proliferation of
     * dot-product instructions, and they can usefully share logic
     * across backends. */
    virtual void codegen_vector_reduce(const VectorReduce *op, const Expr &init);

    /** Are we inside an atomic node that uses mutex locks?
        This is used for detecting deadlocks from nested atomics & illegal vectorization. */
    bool inside_atomic_mutex_node;

    /** Emit atomic store instructions? */
    bool emit_atomic_stores;

private:
    /** All the values in scope at the current code location during
     * codegen. Use sym_push and sym_pop to access. */
    Scope<llvm::Value *> symbol_table;

    /** String constants already emitted to the module. Tracked to
     * prevent emitting the same string many times. */
    std::map<std::string, llvm::Constant *> string_constants;

    /** A basic block to branch to on error that triggers all
     * destructors. As destructors are registered, code gets added
     * to this block. */
    llvm::BasicBlock *destructor_block;

    /** Turn off all unsafe math flags in scopes while this is set. */
    bool strict_float;

    /** Use the LLVM large code model when this is set. */
    bool llvm_large_code_model;

    /** Embed an instance of halide_filter_metadata_t in the code, using
     * the given name (by convention, this should be ${FUNCTIONNAME}_metadata)
     * as extern "C" linkage. Note that the return value is a function-returning-
     * pointer-to-constant-data.
     */
    llvm::Function *embed_metadata_getter(const std::string &metadata_getter_name,
                                          const std::string &function_name, const std::vector<LoweredArgument> &args,
                                          const std::map<std::string, std::string> &metadata_name_map);

    /** Embed a constant expression as a global variable. */
    llvm::Constant *embed_constant_expr(Expr e, llvm::Type *t);
    llvm::Constant *embed_constant_scalar_value_t(const Expr &e);

    /** Wrap 'fn' in a function with an argv-style calling convention.
     * If result_in_argv is true, the result is returned via the argv
     * array rather than the return value. */
    llvm::Function *add_argv_wrapper(llvm::Function *fn, const std::string &name, bool result_in_argv = false);

    /** Emit a dense (unpredicated or predicated-by-vpred) vector load,
     * optionally slicing to the native vector width first. */
    // @{
    llvm::Value *codegen_dense_vector_load(const Type &type, const std::string &name, const Expr &base,
                                           const Buffer<> &image, const Parameter &param, const ModulusRemainder &alignment,
                                           llvm::Value *vpred = nullptr, bool slice_to_native = true);
    llvm::Value *codegen_dense_vector_load(const Load *load, llvm::Value *vpred = nullptr, bool slice_to_native = true);
    // @}

    /** Codegen loads/stores guarded by a vector predicate. */
    // @{
    virtual void codegen_predicated_vector_load(const Load *op);
    virtual void codegen_predicated_vector_store(const Store *op);
    // @}

    /** Codegen a store that must be performed as an atomic read-modify-write. */
    void codegen_atomic_rmw(const Store *op);

    /** Set up / tear down per-module codegen state around compile(). */
    // @{
    void init_codegen(const std::string &name, bool any_strict_float = false);
    std::unique_ptr<llvm::Module> finish_codegen();
    // @}

    /** A helper routine for generating folded vector reductions. */
    template<typename Op>
    bool try_to_fold_vector_reduce(const Expr &a, Expr b);
};
20311
20312} // namespace Internal
20313
20314/** Given a Halide module, generate an llvm::Module. */
20315std::unique_ptr<llvm::Module> codegen_llvm(const Module &module,
20316 llvm::LLVMContext &context);
20317
20318} // namespace Halide
20319
20320#endif
20321#ifndef HALIDE_CODEGEN_METAL_DEV_H
20322#define HALIDE_CODEGEN_METAL_DEV_H
20323
20324/** \file
20325 * Defines the code-generator for producing Apple Metal shading language kernel code
20326 */
20327
20328#include <memory>
20329
20330namespace Halide {
20331
20332struct Target;
20333
20334namespace Internal {
20335
20336struct CodeGen_GPU_Dev;
20337
20338std::unique_ptr<CodeGen_GPU_Dev> new_CodeGen_Metal_Dev(const Target &target);
20339
20340} // namespace Internal
20341} // namespace Halide
20342
20343#endif
20344#ifndef HALIDE_CODEGEN_OPENCL_DEV_H
20345#define HALIDE_CODEGEN_OPENCL_DEV_H
20346
20347/** \file
20348 * Defines the code-generator for producing OpenCL C kernel code
20349 */
20350
20351#include <memory>
20352
20353namespace Halide {
20354
20355struct Target;
20356
20357namespace Internal {
20358
20359struct CodeGen_GPU_Dev;
20360
20361std::unique_ptr<CodeGen_GPU_Dev> new_CodeGen_OpenCL_Dev(const Target &target);
20362
20363} // namespace Internal
20364} // namespace Halide
20365
20366#endif
20367#ifndef HALIDE_CODEGEN_OPENGLCOMPUTE_DEV_H
20368#define HALIDE_CODEGEN_OPENGLCOMPUTE_DEV_H
20369
20370/** \file
20371 * Defines the code-generator for producing GLSL kernel code for OpenGL Compute.
20372 */
20373
20374#include <memory>
20375
20376namespace Halide {
20377
20378struct Target;
20379
20380namespace Internal {
20381
20382struct CodeGen_GPU_Dev;
20383
20384std::unique_ptr<CodeGen_GPU_Dev> new_CodeGen_OpenGLCompute_Dev(const Target &target);
20385
20386} // namespace Internal
20387} // namespace Halide
20388
20389#endif
20390#ifndef HALIDE_CODEGEN_POSIX_H
20391#define HALIDE_CODEGEN_POSIX_H
20392
20393/** \file
20394 * Defines a base-class for code-generators on posixy cpu platforms
20395 */
20396
20397
20398namespace Halide {
20399namespace Internal {
20400
20401/** A code generator that emits posix code from a given Halide stmt. */
/** A code generator that emits posix code from a given Halide stmt. */
class CodeGen_Posix : public CodeGen_LLVM {
public:
    /** Create a posix code generator. Processor features can be
     * enabled using the appropriate arguments */
    CodeGen_Posix(const Target &t);

protected:
    using CodeGen_LLVM::visit;

    /** Posix implementation of Allocate. Small constant-sized allocations go
     * on the stack. The rest go on the heap by calling "halide_malloc"
     * and "halide_free" in the standard library. */
    // @{
    void visit(const Allocate *) override;
    void visit(const Free *) override;
    // @}

    /** It can be convenient for backends to assume there is extra
     * padding beyond the end of a buffer to enable faster
     * loads/stores. This function gets the padding required by the
     * implementing target. */
    virtual int allocation_padding(Type type) const;

    /** A struct describing heap or stack allocations. */
    struct Allocation {
        /** The memory */
        llvm::Value *ptr = nullptr;

        /** Destructor stack slot for this allocation. */
        llvm::Value *destructor = nullptr;

        /** Function to accomplish the destruction. */
        llvm::Function *destructor_function = nullptr;

        /** Pseudostack slot for this allocation. Non-null for
         * allocations of type Stack with dynamic size. */
        llvm::Value *pseudostack_slot = nullptr;

        /** The (Halide) type of the allocation. */
        Type type;

        /** How many bytes this allocation is, or 0 if not
         * constant. */
        int constant_bytes = 0;

        /** How many bytes of stack space used. 0 implies it was a
         * heap allocation. */
        int stack_bytes = 0;

        /** A unique name for this allocation. May not be equal to the
         * Allocate node name in cases where we detect multiple
         * Allocate nodes can share a single allocation. */
        std::string name;
    };

    /** The allocations currently in scope. The stack gets pushed when
     * we enter a new function. */
    Scope<Allocation> allocations;

    std::string get_allocation_name(const std::string &n) override;

private:
    /** Stack allocations that were freed, but haven't gone out of
     * scope yet. This allows us to re-use stack allocations when
     * they aren't being used. */
    std::vector<Allocation> free_stack_allocs;

    /** current size of all alloca instances in use; this is tracked only
     * for debug output purposes. */
    size_t cur_stack_alloc_total{0};

    /** Generates code for computing the size of an allocation from a
     * list of its extents and its size. Fires a runtime assert
     * (halide_error) if the size overflows 2^31 -1, the maximum
     * positive number an int32_t can hold. */
    llvm::Value *codegen_allocation_size(const std::string &name, Type type, const std::vector<Expr> &extents, const Expr &condition);

    /** Allocates some memory on either the stack or the heap, and
     * returns an Allocation object describing it. For heap
     * allocations this calls halide_malloc in the runtime, and for
     * stack allocations it either reuses an existing block from the
     * free_stack_allocs list, or it saves the stack pointer and calls
     * alloca.
     *
     * This call returns the allocation, pushes it onto the
     * 'allocations' map, and adds an entry to the symbol table called
     * name.host that provides the base pointer.
     *
     * When the allocation can be freed call 'free_allocation'.
     * NOTE(review): the original comment also said to call
     * 'destroy_allocation' when it goes out of scope, but no such
     * member is declared in this class — confirm and update. */
    Allocation create_allocation(const std::string &name, Type type, MemoryType memory_type,
                                 const std::vector<Expr> &extents,
                                 const Expr &condition, const Expr &new_expr, std::string free_function);

    /** Free an allocation previously allocated with
     * create_allocation */
    void free_allocation(const std::string &name);
};
20500
20501} // namespace Internal
20502} // namespace Halide
20503
20504#endif
20505#ifndef HALIDE_CODEGEN_PTX_DEV_H
20506#define HALIDE_CODEGEN_PTX_DEV_H
20507
/** \file
 * Defines the code-generator for producing CUDA device code (PTX)
 */
20511
20512#include <memory>
20513
20514namespace Halide {
20515
20516struct Target;
20517
20518namespace Internal {
20519
20520struct CodeGen_GPU_Dev;
20521
20522std::unique_ptr<CodeGen_GPU_Dev> new_CodeGen_PTX_Dev(const Target &target);
20523
20524} // namespace Internal
20525} // namespace Halide
20526
20527#endif
20528#ifndef HALIDE_CODEGEN_PYTORCH_H
20529#define HALIDE_CODEGEN_PYTORCH_H
20530
20531/** \file
20532 *
20533 * Defines an IRPrinter that emits C++ code that:
 * 1. wraps PyTorch's C++ tensors into Halide buffers,
20535 * 2. calls the corresponding Halide operator.
20536 * 3. maps the output buffer back to a PyTorch tensor.
20537 *
 * The generated code checks for runtime errors and raises PyTorch exceptions
 * accordingly. It also makes sure the GPU device and stream are consistent
 * with those of the PyTorch input, when applicable.
20541 */
20542
20543
20544namespace Halide {
20545
20546class Module;
20547
20548namespace Internal {
20549
20550struct LoweredFunc;
20551
20552/** This class emits C++ code to wrap a Halide pipeline so that it can
20553 * be used as a C++ extension operator in PyTorch.
20554 */
class CodeGen_PyTorch : public IRPrinter {
public:
    /** Construct a code generator that writes the wrapper code to 'dest'. */
    CodeGen_PyTorch(std::ostream &dest);
    ~CodeGen_PyTorch() override = default;

    /** Emit the PyTorch C++ wrapper for the Halide pipeline. */
    void compile(const Module &module);

    /** Run an internal self-test. */
    static void test();

private:
    /** Emit the wrapper for a single lowered function. 'is_cuda'
     * presumably selects the CUDA-aware code path — confirm against
     * the implementation. */
    void compile(const LoweredFunc &func, bool is_cuda);
};
20568
20569} // namespace Internal
20570} // namespace Halide
20571
20572#endif // HALIDE_CODEGEN_PYTORCH_H
20573#ifndef HALIDE_CODEGEN_TARGETS_H
20574#define HALIDE_CODEGEN_TARGETS_H
20575
20576/** \file
20577 * Provides constructors for code generators for various targets.
20578 */
20579
20580#include <memory>
20581
20582namespace Halide {
20583
20584struct Target;
20585
20586namespace Internal {
20587
20588class CodeGen_Posix;
20589
20590/** Construct CodeGen object for a variety of targets. */
20591std::unique_ptr<CodeGen_Posix> new_CodeGen_ARM(const Target &target);
20592std::unique_ptr<CodeGen_Posix> new_CodeGen_Hexagon(const Target &target);
20593std::unique_ptr<CodeGen_Posix> new_CodeGen_MIPS(const Target &target);
20594std::unique_ptr<CodeGen_Posix> new_CodeGen_PowerPC(const Target &target);
20595std::unique_ptr<CodeGen_Posix> new_CodeGen_RISCV(const Target &target);
20596std::unique_ptr<CodeGen_Posix> new_CodeGen_X86(const Target &target);
20597std::unique_ptr<CodeGen_Posix> new_CodeGen_WebAssembly(const Target &target);
20598
20599} // namespace Internal
20600} // namespace Halide
20601
20602#endif
20603#ifndef HALIDE_COMPILER_LOGGER_H_
20604#define HALIDE_COMPILER_LOGGER_H_
20605
20606/** \file
20607 * Defines an interface used to gather and log compile-time information, stats, etc
20608 * for use in evaluating internal Halide compilation rules and efficiency.
20609 *
20610 * The 'standard' implementation simply logs all gathered data to
20611 * a local file (in JSON form), but the entire implementation can be
20612 * replaced by custom definitions if you have unusual logging needs.
20613 */
20614
20615#include <iostream>
20616#include <map>
20617#include <memory>
20618#include <string>
20619#include <utility>
20620
20621
20622namespace Halide {
20623namespace Internal {
20624
/** Abstract interface used to gather compile-time information, stats, etc.
 * during a pipeline compilation. Implementations choose how the recorded
 * data is stored and emitted; the standard implementation logs everything
 * in JSON form (see JSONCompilerLogger). */
class CompilerLogger {
public:
    /** The "Phase" of compilation, used for some calls */
    enum class Phase {
        HalideLowering,  // Halide-side lowering of the pipeline
        LLVM,            // LLVM code generation
    };

    CompilerLogger() = default;
    virtual ~CompilerLogger() = default;

    /** Record when a particular simplifier rule matches.
     */
    virtual void record_matched_simplifier_rule(const std::string &rulename, Expr expr) = 0;

    /** Record when an expression is non-monotonic in a loop variable.
     */
    virtual void record_non_monotonic_loop_var(const std::string &loop_var, Expr expr) = 0;

    /** Record when can_prove() fails, but cannot find a counterexample.
     */
    virtual void record_failed_to_prove(Expr failed_to_prove, Expr original_expr) = 0;

    /** Record total size (in bytes) of final generated object code (e.g., file size of .o output).
     */
    virtual void record_object_code_size(uint64_t bytes) = 0;

    /** Record the compilation time (in seconds) for a given phase.
     */
    virtual void record_compilation_time(Phase phase, double duration) = 0;

    /**
     * Emit all the gathered data to the given stream. This may be called multiple times.
     */
    virtual std::ostream &emit_to_stream(std::ostream &o) = 0;
};
20661
20662/** Set the active CompilerLogger object, replacing any existing one.
20663 * It is legal to pass in a nullptr (which means "don't do any compiler logging").
20664 * Returns the previous CompilerLogger (if any). */
20665std::unique_ptr<CompilerLogger> set_compiler_logger(std::unique_ptr<CompilerLogger> compiler_logger);
20666
20667/** Return the currently active CompilerLogger object. If set_compiler_logger()
20668 * has never been called, a nullptr implementation will be returned.
20669 * Do not save the pointer returned! It is intended to be used for immediate
20670 * calls only. */
20671CompilerLogger *get_compiler_logger();
20672
20673/** JSONCompilerLogger is a basic implementation of the CompilerLogger interface
20674 * that saves logged data, then logs it all in JSON format in emit_to_stream().
20675 */
/** JSONCompilerLogger is a basic implementation of the CompilerLogger interface
 * that saves logged data, then logs it all in JSON format in emit_to_stream().
 */
class JSONCompilerLogger : public CompilerLogger {
public:
    JSONCompilerLogger() = default;

    /** Construct a logger tagged with pipeline metadata that is emitted
     * alongside the gathered data. If obfuscate_exprs is true, recorded
     * Exprs are presumably run through obfuscate() before emission —
     * confirm in CompilerLogger.cpp. */
    JSONCompilerLogger(
        const std::string &generator_name,
        const std::string &function_name,
        const std::string &autoscheduler_name,
        const Target &target,
        const std::string &generator_args,
        bool obfuscate_exprs);

    void record_matched_simplifier_rule(const std::string &rulename, Expr expr) override;
    void record_non_monotonic_loop_var(const std::string &loop_var, Expr expr) override;
    void record_failed_to_prove(Expr failed_to_prove, Expr original_expr) override;
    void record_object_code_size(uint64_t bytes) override;
    void record_compilation_time(Phase phase, double duration) override;

    /** Emit all gathered data (as JSON) to the given stream. */
    std::ostream &emit_to_stream(std::ostream &o) override;

protected:
    // Metadata identifying the pipeline this log describes.
    const std::string generator_name;
    const std::string function_name;
    const std::string autoscheduler_name;
    const Target target = Target();
    const std::string generator_args;
    // If true, Exprs are obfuscated (see obfuscate()) before emission.
    const bool obfuscate_exprs{false};

    // Maps from string representing rewrite rule -> list of Exprs that matched that rule
    std::map<std::string, std::vector<Expr>> matched_simplifier_rules;

    // Maps loop_var -> list of Exprs that were nonmonotonic for that loop_var
    std::map<std::string, std::vector<Expr>> non_monotonic_loop_vars;

    // List of (unprovable simplified Expr, original version of that Expr passed to can_prove()).
    std::vector<std::pair<Expr, Expr>> failed_to_prove_exprs;

    // Total code size generated, in bytes.
    uint64_t object_code_size{0};

    // Map of the time taken for each phase of compilation.
    std::map<Phase, double> compilation_time;

    // Obfuscate the recorded Exprs in place (implementation detail).
    void obfuscate();
    // Serialize the recorded data (implementation detail of emit_to_stream()).
    void emit();
};
20722
20723} // namespace Internal
20724} // namespace Halide
20725
20726#endif // HALIDE_COMPILER_LOGGER_H_
20727#ifndef HALIDE_CONCISE_CASTS_H
20728#define HALIDE_CONCISE_CASTS_H
20729
20730
20731/** \file
20732 *
20733 * Defines concise cast and saturating cast operators to make it
20734 * easier to read cast-heavy code. Think carefully about the
20735 * readability implications before using these. They could make your
20736 * code better or worse. Often it's better to add extra Funcs to your
20737 * pipeline that do the upcasting and downcasting.
20738 */
20739
20740namespace Halide {
20741namespace ConciseCasts {
20742
20743inline Expr f64(Expr e) {
20744 Type t = Float(64, e.type().lanes());
20745 return cast(t, std::move(e));
20746}
20747
20748inline Expr f32(Expr e) {
20749 Type t = Float(32, e.type().lanes());
20750 return cast(t, std::move(e));
20751}
20752
20753inline Expr bf16(Expr e) {
20754 Type t = BFloat(16, e.type().lanes());
20755 return cast(t, std::move(e));
20756}
20757
20758inline Expr i64(Expr e) {
20759 Type t = Int(64, e.type().lanes());
20760 return cast(t, std::move(e));
20761}
20762
20763inline Expr i32(Expr e) {
20764 Type t = Int(32, e.type().lanes());
20765 return cast(t, std::move(e));
20766}
20767
20768inline Expr i16(Expr e) {
20769 Type t = Int(16, e.type().lanes());
20770 return cast(t, std::move(e));
20771}
20772
20773inline Expr i8(Expr e) {
20774 Type t = Int(8, e.type().lanes());
20775 return cast(t, std::move(e));
20776}
20777
20778inline Expr u64(Expr e) {
20779 Type t = UInt(64, e.type().lanes());
20780 return cast(t, std::move(e));
20781}
20782
20783inline Expr u32(Expr e) {
20784 Type t = UInt(32, e.type().lanes());
20785 return cast(t, std::move(e));
20786}
20787
20788inline Expr u16(Expr e) {
20789 Type t = UInt(16, e.type().lanes());
20790 return cast(t, std::move(e));
20791}
20792
20793inline Expr u8(Expr e) {
20794 Type t = UInt(8, e.type().lanes());
20795 return cast(t, std::move(e));
20796}
20797
20798inline Expr i8_sat(Expr e) {
20799 Type t = Int(8, e.type().lanes());
20800 return saturating_cast(t, std::move(e));
20801}
20802
20803inline Expr u8_sat(Expr e) {
20804 Type t = UInt(8, e.type().lanes());
20805 return saturating_cast(t, std::move(e));
20806}
20807
20808inline Expr i16_sat(Expr e) {
20809 Type t = Int(16, e.type().lanes());
20810 return saturating_cast(t, std::move(e));
20811}
20812
20813inline Expr u16_sat(Expr e) {
20814 Type t = UInt(16, e.type().lanes());
20815 return saturating_cast(t, std::move(e));
20816}
20817
20818inline Expr i32_sat(Expr e) {
20819 Type t = Int(32, e.type().lanes());
20820 return saturating_cast(t, std::move(e));
20821}
20822
20823inline Expr u32_sat(Expr e) {
20824 Type t = UInt(32, e.type().lanes());
20825 return saturating_cast(t, std::move(e));
20826}
20827
20828inline Expr i64_sat(Expr e) {
20829 Type t = Int(64, e.type().lanes());
20830 return saturating_cast(t, std::move(e));
20831}
20832
20833inline Expr u64_sat(Expr e) {
20834 Type t = UInt(64, e.type().lanes());
20835 return saturating_cast(t, std::move(e));
20836}
20837
20838}; // namespace ConciseCasts
20839}; // namespace Halide
20840
20841#endif
20842#ifndef HALIDE_CPLUSPLUS_MANGLE_H
20843#define HALIDE_CPLUSPLUS_MANGLE_H
20844
20845/** \file
20846 *
20847 * A simple function to get a C++ mangled function name for a function.
20848 */
20849#include <string>
20850#include <vector>
20851
20852
20853namespace Halide {
20854
20855struct ExternFuncArgument;
20856struct Target;
20857
20858namespace Internal {
20859
/** Return the mangled C++ name for a function.
 * The target parameter is used to decide on the C++
 * ABI/mangling style to use.
 *
 * \param name the unqualified function name
 * \param namespaces the C++ namespaces enclosing the function
 * \param return_type the function's return type
 * \param args the argument list whose types participate in mangling
 * \param target selects the mangling convention
 */
std::string cplusplus_function_mangled_name(const std::string &name,
                                            const std::vector<std::string> &namespaces,
                                            Type return_type,
                                            const std::vector<ExternFuncArgument> &args,
                                            const Target &target);

/** Internal self-test for the mangler. */
void cplusplus_mangle_test();
20871
20872} // namespace Internal
20873
20874} // namespace Halide
20875
20876#endif
20877#ifndef HALIDE_INTERNAL_CSE_H
20878#define HALIDE_INTERNAL_CSE_H
20879
20880/** \file
20881 * Defines a pass for introducing let expressions to wrap common sub-expressions. */
20882
20883
20884namespace Halide {
20885namespace Internal {
20886
20887/** Replace each common sub-expression in the argument with a
20888 * variable, and wrap the resulting expr in a let statement giving a
20889 * value to that variable.
20890 *
20891 * This is important to do within Halide (instead of punting to llvm),
20892 * because exprs that come in from the front-end are small when
20893 * considered as a graph, but combinatorially large when considered as
 * a tree. For an example of such a case, see
20895 * test/code_explosion.cpp
20896 *
20897 * The last parameter determines whether all common subexpressions are
 * lifted, or only those that the simplifier would not substitute back
20899 * in (e.g. addition of a constant).
20900 */
20901Expr common_subexpression_elimination(const Expr &, bool lift_all = false);
20902
20903/** Do common-subexpression-elimination on each expression in a
20904 * statement. Does not introduce let statements. */
20905Stmt common_subexpression_elimination(const Stmt &, bool lift_all = false);
20906
20907void cse_test();
20908
20909} // namespace Internal
20910} // namespace Halide
20911
20912#endif
20913#ifndef HALIDE_INTERNAL_DEBUG_ARGUMENTS_H
20914#define HALIDE_INTERNAL_DEBUG_ARGUMENTS_H
20915
20916
20917/** \file
20918 *
20919 * Defines a lowering pass that injects debug statements inside a
20920 * LoweredFunc. Intended to be used when Target::Debug is on.
20921 */
20922
20923namespace Halide {
20924namespace Internal {
20925
20926struct LoweredFunc;
20927
20928/** Injects debug prints in a LoweredFunc that describe the target and
20929 * arguments. Mutates the given func. */
20930void debug_arguments(LoweredFunc *func, const Target &t);
20931
20932} // namespace Internal
20933} // namespace Halide
20934
20935#endif
20936#ifndef HALIDE_DEBUG_TO_FILE_H
20937#define HALIDE_DEBUG_TO_FILE_H
20938
20939/** \file
20940 * Defines the lowering pass that injects code at the end of
20941 * every realization to dump functions to a file for debugging. */
20942
20943#include <map>
20944#include <vector>
20945
20946
20947namespace Halide {
20948namespace Internal {
20949
20950class Function;
20951
20952/** Takes a statement with Realize nodes still unlowered. If the
20953 * corresponding functions have a debug_file set, then inject code
20954 * that will dump the contents of those functions to a file after the
20955 * realization. */
20956Stmt debug_to_file(Stmt s,
20957 const std::vector<Function> &outputs,
20958 const std::map<std::string, Function> &env);
20959
20960} // namespace Internal
20961} // namespace Halide
20962
20963#endif
20964#ifndef DEINTERLEAVE_H
20965#define DEINTERLEAVE_H
20966
20967/** \file
20968 *
20969 * Defines methods for splitting up a vector into the even lanes and
20970 * the odd lanes. Useful for optimizing expressions such as select(x %
20971 * 2, f(x/2), g(x/2))
20972 */
20973
20974
20975namespace Halide {
20976namespace Internal {
20977
20978/** Extract the odd-numbered lanes in a vector */
20979Expr extract_odd_lanes(const Expr &a);
20980
20981/** Extract the even-numbered lanes in a vector */
20982Expr extract_even_lanes(const Expr &a);
20983
20984/** Extract the nth lane of a vector */
20985Expr extract_lane(const Expr &vec, int lane);
20986
20987/** Look through a statement for expressions of the form select(ramp %
20988 * 2 == 0, a, b) and replace them with calls to an interleave
20989 * intrinsic */
20990Stmt rewrite_interleavings(const Stmt &s);
20991
20992void deinterleave_vector_test();
20993
20994} // namespace Internal
20995} // namespace Halide
20996
20997#endif
20998#ifndef HALIDE_DERIVATIVE_H
20999#define HALIDE_DERIVATIVE_H
21000
21001/** \file
21002 * Automatic differentiation
21003 */
21004
21005
21006#include <map>
21007#include <string>
21008#include <vector>
21009
21010namespace Halide {
21011
21012/**
21013 * Helper structure storing the adjoints Func.
21014 * Use d(func) or d(buffer) to obtain the derivative Func.
21015 */
class Derivative {
public:
    // function name & update_id, for initialization update_id == -1
    using FuncKey = std::pair<std::string, int>;

    /** Wrap an existing map of adjoints (copied). */
    explicit Derivative(const std::map<FuncKey, Func> &adjoints_in)
        : adjoints(adjoints_in) {
    }
    /** Wrap an existing map of adjoints (moved, avoiding a copy). */
    explicit Derivative(std::map<FuncKey, Func> &&adjoints_in)
        : adjoints(std::move(adjoints_in)) {
    }

    // These all return an undefined Func if no derivative is found
    // (typically, if the input Funcs aren't differentiable)
    /** Derivative w.r.t. a Func: the pure definition (update_id == -1) or a given update. */
    Func operator()(const Func &func, int update_id = -1) const;
    /** Derivative w.r.t. an input buffer. */
    Func operator()(const Buffer<> &buffer) const;
    /** Derivative w.r.t. a parameter. */
    Func operator()(const Param<> &param) const;

private:
    // Immutable map from (func name, update_id) to the adjoint Func.
    const std::map<FuncKey, Func> adjoints;
};
21037
21038/**
21039 * Given a Func and a corresponding adjoint, (back)propagate the
21040 * adjoint to all dependent Funcs, buffers, and parameters.
21041 * The bounds of output and adjoint need to be specified with pair {min, extent}
21042 * For each Func the output depends on, and for the pure definition and
21043 * each update of that Func, it generates a derivative Func stored in
21044 * the Derivative.
21045 */
21046Derivative propagate_adjoints(const Func &output,
21047 const Func &adjoint,
21048 const Region &output_bounds);
21049/**
21050 * Given a Func and a corresponding adjoint buffer, (back)propagate the
21051 * adjoint to all dependent Funcs, buffers, and parameters.
21052 * For each Func the output depends on, and for the pure definition and
21053 * each update of that Func, it generates a derivative Func stored in
21054 * the Derivative.
21055 */
21056Derivative propagate_adjoints(const Func &output,
21057 const Buffer<float> &adjoint);
21058/**
21059 * Given a scalar Func with size 1, (back)propagate the gradient
21060 * to all dependent Funcs, buffers, and parameters.
21061 * For each Func the output depends on, and for the pure definition and
21062 * each update of that Func, it generates a derivative Func stored in
21063 * the Derivative.
21064 */
21065Derivative propagate_adjoints(const Func &output);
21066
21067} // namespace Halide
21068
21069#endif
21070#ifndef HALIDE_INTERNAL_DERIVATIVE_UTILS_H
21071#define HALIDE_INTERNAL_DERIVATIVE_UTILS_H
21072
21073#include <set>
21074
21075
21076namespace Halide {
21077namespace Internal {
21078
21079/**
21080 * Remove all let definitions of expr
21081 */
21082Expr remove_let_definitions(const Expr &expr);
21083
21084/**
21085 * Return a list of variables' indices that expr depends on and are in the filter
21086 */
21087std::vector<int> gather_variables(const Expr &expr,
21088 const std::vector<std::string> &filter);
21089std::vector<int> gather_variables(const Expr &expr,
21090 const std::vector<Var> &filter);
21091
21092/**
21093 * Return a list of reduction variables the expression or tuple depends on
21094 */
/** Information about a single reduction variable. */
struct ReductionVariableInfo {
    Expr min, extent;        // bounds of the reduction variable
    int index;               // index of the variable (presumably within its domain) — confirm
    ReductionDomain domain;  // the reduction domain the variable belongs to
    std::string name;        // name of the reduction variable
};
21101std::map<std::string, ReductionVariableInfo> gather_rvariables(const Expr &expr);
21102std::map<std::string, ReductionVariableInfo> gather_rvariables(const Tuple &tuple);
21103/**
21104 * Add necessary let expressions to expr
21105 */
21106Expr add_let_expression(const Expr &expr,
21107 const std::map<std::string, Expr> &let_var_mapping,
21108 const std::vector<std::string> &let_variables);
21109/**
21110 * Topologically sort the expression graph expressed by expr
21111 */
21112std::vector<Expr> sort_expressions(const Expr &expr);
21113/**
21114 * Compute the bounds of funcs. The bounds represent a conservative region
21115 * that is used by the "consumers" of the function, except of itself.
21116 */
21117std::map<std::string, Box> inference_bounds(const std::vector<Func> &funcs,
21118 const std::vector<Box> &output_bounds);
21119std::map<std::string, Box> inference_bounds(const Func &func,
21120 const Box &output_bounds);
21121/**
21122 * Convert Box to vector of (min, extent)
21123 */
21124std::vector<std::pair<Expr, Expr>> box_to_vector(const Box &bounds);
21125/**
21126 * Return true if bounds0 and bounds1 represent the same bounds.
21127 */
21128bool equal(const RDom &bounds0, const RDom &bounds1);
21129/**
21130 * Return a list of variable names
21131 */
21132std::vector<std::string> vars_to_strings(const std::vector<Var> &vars);
21133/**
21134 * Return the reduction domain used by expr
21135 */
21136ReductionDomain extract_rdom(const Expr &expr);
21137/**
21138 * expr is new_var == f(var), solve for var == g(new_var)
21139 * if multiple new_var corresponds to same var, introduce a RDom
21140 */
21141std::pair<bool, Expr> solve_inverse(Expr expr,
21142 const std::string &new_var,
21143 const std::string &var);
21144/**
21145 * Find all calls to image buffers and parameters in the function
21146 */
/** Dimensionality and element type of a called image buffer or parameter. */
struct BufferInfo {
    int dimension;  // number of dimensions of the buffer/parameter
    Type type;      // element type of the buffer/parameter
};
21151std::map<std::string, BufferInfo> find_buffer_param_calls(const Func &func);
21152/**
21153 * Find all implicit variables in expr
21154 */
21155std::set<std::string> find_implicit_variables(const Expr &expr);
21156/**
21157 * Substitute the variable. Also replace all occurrences in rdom.where() predicates.
21158 */
21159Expr substitute_rdom_predicate(
21160 const std::string &name, const Expr &replacement, const Expr &expr);
21161
21162/**
21163 * Return true if expr contains call to func_name
21164 */
21165bool is_calling_function(
21166 const std::string &func_name, const Expr &expr,
21167 const std::map<std::string, Expr> &let_var_mapping);
21168/**
21169 * Return true if expr depends on any function or buffer
21170 */
21171bool is_calling_function(
21172 const Expr &expr,
21173 const std::map<std::string, Expr> &let_var_mapping);
21174
21175/**
21176 * Replaces call to Func f in Expr e such that the call argument at variable_id
21177 * is the pure argument.
21178 */
21179Expr substitute_call_arg_with_pure_arg(Func f,
21180 int variable_id,
21181 const Expr &e);
21182
21183} // namespace Internal
21184} // namespace Halide
21185
21186#endif
21187#ifndef HALIDE_DIMENSION_H
21188#define HALIDE_DIMENSION_H
21189
21190/** \file
21191 * Defines the Dimension utility class for Halide pipelines
21192 */
21193
21194#include <utility>
21195
21196
21197namespace Halide {
21198namespace Internal {
21199
class Dimension {
public:
    /** Get an expression representing the minimum coordinates of this image
     * parameter in the given dimension. */
    Expr min() const;

    /** Get an expression representing the extent of this image
     * parameter in the given dimension */
    Expr extent() const;

    /** Get an expression representing the maximum coordinates of
     * this image parameter in the given dimension. */
    Expr max() const;

    /** Get an expression representing the stride of this image in the
     * given dimension */
    Expr stride() const;

    /** Set the min in a given dimension to equal the given
     * expression. Setting the mins to zero may simplify some
     * addressing math. */
    Dimension set_min(Expr min);

    /** Set the extent in a given dimension to equal the given
     * expression. Images passed in that fail this check will generate
     * a runtime error. Returns a reference to the ImageParam so that
     * these calls may be chained.
     *
     * This may help the compiler generate better
     * code. E.g:
     \code
     im.dim(0).set_extent(100);
     \endcode
     * tells the compiler that dimension zero must be of extent 100,
     * which may result in simplification of boundary checks. The
     * value can be an arbitrary expression:
     \code
     im.dim(0).set_extent(im.dim(1).extent());
     \endcode
     * declares that im is a square image (of unknown size), whereas:
     \code
     im.dim(0).set_extent((im.dim(0).extent()/32)*32);
     \endcode
     * tells the compiler that the extent is a multiple of 32. */
    Dimension set_extent(Expr extent);

    /** Set the stride in a given dimension to equal the given
     * value. This is particularly helpful to set when
     * vectorizing. Known strides for the vectorized dimension
     * generate better code. */
    Dimension set_stride(Expr stride);

    /** Set the min and extent in one call. */
    Dimension set_bounds(Expr min, Expr extent);

    /** Set the min and extent estimates in one call. These values are only
     * used by the auto-scheduler and/or the RunGen tool. */
    Dimension set_estimate(Expr min, Expr extent);

    /** Accessors for the estimates set via set_estimate(). */
    Expr min_estimate() const;
    Expr extent_estimate() const;

    /** Get a different dimension of the same buffer */
    // @{
    Dimension dim(int i) const;
    // @}

private:
    friend class ::Halide::OutputImageParam;

    /** Construct a Dimension representing dimension d of some
     * Internal::Parameter p. Only friends may construct
     * these. */
    Dimension(const Internal::Parameter &p, int d, Func f);

    Parameter param;  // the Parameter this Dimension describes
    int d;            // which dimension of param this refers to
    Func f;           // NOTE(review): role of this Func is not evident from this header — confirm in Dimension.cpp
};
21279
21280} // namespace Internal
21281} // namespace Halide
21282
21283#endif
21284#ifndef HALIDE_EARLY_FREE_H
21285#define HALIDE_EARLY_FREE_H
21286
21287/** \file
21288 * Defines the lowering pass that injects markers just after
21289 * the last use of each buffer so that they can potentially be freed
21290 * earlier.
21291 */
21292
21293
21294namespace Halide {
21295namespace Internal {
21296
21297/** Take a statement with allocations and inject markers (of the form
21298 * of calls to "mark buffer dead") after the last use of each
21299 * allocation. Targets may use this to free buffers earlier than the
21300 * close of their Allocate node. */
21301Stmt inject_early_frees(const Stmt &s);
21302
21303} // namespace Internal
21304} // namespace Halide
21305
21306#endif
21307#ifndef HALIDE_ELF_H
21308#define HALIDE_ELF_H
21309
21310#include <algorithm>
21311#include <iterator>
21312#include <list>
21313#include <memory>
21314#include <string>
21315#include <utility>
21316#include <vector>
21317
21318namespace Halide {
21319namespace Internal {
21320namespace Elf {
21321
21322// This ELF parser mostly deserializes the object into a graph
21323// structure in memory. It replaces indices into tables (sections,
21324// symbols, etc.) with a weakly referenced graph of pointers. The
21325// Object datastructure owns all of the objects. This namespace exists
21326// because it is very difficult to use LLVM's object parser to modify
21327// an object (it's fine for parsing only). This was built using
21328// http://www.skyfree.org/linux/references/ELF_Format.pdf as a reference
21329// for the ELF structs and constants.
21330
21331class Object;
21332class Symbol;
21333class Section;
21334class Relocation;
21335
// Helpful wrapper bundling a begin/end iterator pair so the pair can be
// used directly in range-based for loops.
template<typename T>
class iterator_range {
    T first_, last_;

public:
    iterator_range(T first, T last)
        : first_(first), last_(last) {
    }

    T begin() const {
        return first_;
    }
    T end() const {
        return last_;
    }
};
21353
21354/** Describes a symbol */
21355class Symbol {
21356public:
21357 enum Binding : uint8_t {
21358 STB_LOCAL = 0,
21359 STB_GLOBAL = 1,
21360 STB_WEAK = 2,
21361 STB_LOPROC = 13,
21362 STB_HIPROC = 15,
21363 };
21364
21365 enum Type : uint8_t {
21366 STT_NOTYPE = 0,
21367 STT_OBJECT = 1,
21368 STT_FUNC = 2,
21369 STT_SECTION = 3,
21370 STT_FILE = 4,
21371 STT_LOPROC = 13,
21372 STT_HIPROC = 15,
21373 };
21374
21375 enum Visibility : uint8_t {
21376 STV_DEFAULT = 0,
21377 STV_INTERNAL = 1,
21378 STV_HIDDEN = 2,
21379 STV_PROTECTED = 3,
21380 };
21381
21382private:
21383 std::string name;
21384 const Section *definition = nullptr;
21385 uint64_t offset = 0;
21386 uint32_t size = 0;
21387 Binding binding = STB_LOCAL;
21388 Type type = STT_NOTYPE;
21389 Visibility visibility = STV_DEFAULT;
21390
21391public:
21392 Symbol() = default;
21393 Symbol(const std::string &name)
21394 : name(name) {
21395 }
21396
21397 /** Accesses the name of this symbol. */
21398 ///@{
21399 Symbol &set_name(const std::string &name) {
21400 this->name = name;
21401 return *this;
21402 }
21403 const std::string &get_name() const {
21404 return name;
21405 }
21406 ///@}
21407
21408 /** Accesses the type of this symbol. */
21409 ///@{
21410 Symbol &set_type(Type type) {
21411 this->type = type;
21412 return *this;
21413 }
21414 Type get_type() const {
21415 return type;
21416 }
21417 ///@}
21418
21419 /** Accesses the properties that describe the definition of this symbol. */
21420 ///@{
21421 Symbol &define(const Section *section, uint64_t offset, uint32_t size) {
21422 this->definition = section;
21423 this->offset = offset;
21424 this->size = size;
21425 return *this;
21426 }
21427 bool is_defined() const {
21428 return definition != nullptr;
21429 }
21430 const Section *get_section() const {
21431 return definition;
21432 }
21433 uint64_t get_offset() const {
21434 return offset;
21435 }
21436 uint32_t get_size() const {
21437 return size;
21438 }
21439 ///@}
21440
21441 /** Access the binding and visibility of this symbol. See the ELF
21442 * spec for more information about these properties. */
21443 ///@{
21444 Symbol &set_binding(Binding binding) {
21445 this->binding = binding;
21446 return *this;
21447 }
21448 Symbol &set_visibility(Visibility visibility) {
21449 this->visibility = visibility;
21450 return *this;
21451 }
21452 Binding get_binding() const {
21453 return binding;
21454 }
21455 Visibility get_visibility() const {
21456 return visibility;
21457 }
21458 ///@}
21459};
21460
21461/** Describes a relocation to be applied to an offset of a section in
21462 * an Object. */
21463class Relocation {
21464 uint32_t type = 0;
21465 uint64_t offset = 0;
21466 int64_t addend = 0;
21467 const Symbol *symbol = nullptr;
21468
21469public:
21470 Relocation() = default;
21471 Relocation(uint32_t type, uint64_t offset, int64_t addend, const Symbol *symbol)
21472 : type(type), offset(offset), addend(addend), symbol(symbol) {
21473 }
21474
21475 /** The type of relocation to be applied. The meaning of this
21476 * value depends on the machine of the object. */
21477 ///@{
21478 Relocation &set_type(uint32_t type) {
21479 this->type = type;
21480 return *this;
21481 }
21482 uint32_t get_type() const {
21483 return type;
21484 }
21485 ///@}
21486
21487 /** Where to apply the relocation. This is relative to the section
21488 * the relocation belongs to. */
21489 ///@{
21490 Relocation &set_offset(uint64_t offset) {
21491 this->offset = offset;
21492 return *this;
21493 }
21494 uint64_t get_offset() const {
21495 return offset;
21496 }
21497 ///@}
21498
21499 /** The value to replace with the relocation is the address of the symbol plus the addend. */
21500 ///@{
21501 Relocation &set_symbol(const Symbol *symbol) {
21502 this->symbol = symbol;
21503 return *this;
21504 }
21505 Relocation &set_addend(int64_t addend) {
21506 this->addend = addend;
21507 return *this;
21508 }
21509 const Symbol *get_symbol() const {
21510 return symbol;
21511 }
21512 int64_t get_addend() const {
21513 return addend;
21514 }
21515 ///@}
21516};
21517
21518/** Describes a section of an object file. */
21519class Section {
21520public:
21521 enum Type : uint32_t {
21522 SHT_NULL = 0,
21523 SHT_PROGBITS = 1,
21524 SHT_SYMTAB = 2,
21525 SHT_STRTAB = 3,
21526 SHT_RELA = 4,
21527 SHT_HASH = 5,
21528 SHT_DYNAMIC = 6,
21529 SHT_NOTE = 7,
21530 SHT_NOBITS = 8,
21531 SHT_REL = 9,
21532 SHT_SHLIB = 10,
21533 SHT_DYNSYM = 11,
21534 SHT_LOPROC = 0x70000000,
21535 SHT_HIPROC = 0x7fffffff,
21536 SHT_LOUSER = 0x80000000,
21537 SHT_HIUSER = 0xffffffff,
21538 };
21539
21540 enum Flag : uint32_t {
21541 SHF_WRITE = 0x1,
21542 SHF_ALLOC = 0x2,
21543 SHF_EXECINSTR = 0x4,
21544 SHF_MASKPROC = 0xf0000000,
21545 };
21546
21547 typedef std::vector<Relocation> RelocationList;
21548 typedef RelocationList::iterator relocation_iterator;
21549 typedef RelocationList::const_iterator const_relocation_iterator;
21550
21551 typedef std::vector<char>::iterator contents_iterator;
21552 typedef std::vector<char>::const_iterator const_contents_iterator;
21553
21554private:
21555 std::string name;
21556 Type type = SHT_NULL;
21557 uint32_t flags = 0;
21558 std::vector<char> contents;
21559 // Sections may have a size larger than the contents.
21560 uint64_t size = 0;
21561 uint64_t alignment = 1;
21562 RelocationList relocs;
21563
21564public:
21565 Section() = default;
21566 Section(const std::string &name, Type type)
21567 : name(name), type(type) {
21568 }
21569
21570 Section &set_name(const std::string &name) {
21571 this->name = name;
21572 return *this;
21573 }
21574 const std::string &get_name() const {
21575 return name;
21576 }
21577
21578 Section &set_type(Type type) {
21579 this->type = type;
21580 return *this;
21581 }
21582 Type get_type() const {
21583 return type;
21584 }
21585
21586 Section &set_flag(Flag flag) {
21587 this->flags |= flag;
21588 return *this;
21589 }
21590 Section &remove_flag(Flag flag) {
21591 this->flags &= ~flag;
21592 return *this;
21593 }
21594 Section &set_flags(uint32_t flags) {
21595 this->flags = flags;
21596 return *this;
21597 }
21598 uint32_t get_flags() const {
21599 return flags;
21600 }
21601 bool is_alloc() const {
21602 return (flags & SHF_ALLOC) != 0;
21603 }
21604 bool is_writable() const {
21605 return (flags & SHF_WRITE) != 0;
21606 }
21607
21608 /** Get or set the size of the section. The size may be larger
21609 * than the content. */
21610 ///@{
21611 Section &set_size(uint64_t size) {
21612 this->size = size;
21613 return *this;
21614 }
21615 uint64_t get_size() const {
21616 return std::max((uint64_t)size, (uint64_t)contents.size());
21617 }
21618 ///@}
21619
21620 Section &set_alignment(uint64_t alignment) {
21621 this->alignment = alignment;
21622 return *this;
21623 }
21624 uint64_t get_alignment() const {
21625 return alignment;
21626 }
21627
21628 Section &set_contents(std::vector<char> contents) {
21629 this->contents = std::move(contents);
21630 return *this;
21631 }
21632 template<typename It>
21633 Section &set_contents(It begin, It end) {
21634 this->contents.assign(begin, end);
21635 return *this;
21636 }
21637 template<typename It>
21638 Section &append_contents(It begin, It end) {
21639 this->contents.insert(this->contents.end(), begin, end);
21640 return *this;
21641 }
21642 template<typename It>
21643 Section &prepend_contents(It begin, It end) {
21644 typedef typename std::iterator_traits<It>::value_type T;
21645 uint64_t size_bytes = std::distance(begin, end) * sizeof(T);
21646 this->contents.insert(this->contents.begin(), begin, end);
21647
21648 // When we add data to the start of the section, we need to fix up
21649 // the offsets of the relocations linked to this section.
21650 for (Relocation &r : relocations()) {
21651 r.set_offset(r.get_offset() + size_bytes);
21652 }
21653
21654 return *this;
21655 }
21656 /** Set, append or prepend an object to the contents, assuming T is a
21657 * trivially copyable datatype. */
21658 template<typename T>
21659 Section &set_contents(const std::vector<T> &contents) {
21660 this->contents.assign((const char *)contents.data(), (const char *)(contents.data() + contents.size()));
21661 return *this;
21662 }
21663 template<typename T>
21664 Section &append_contents(const T &x) {
21665 return append_contents((const char *)&x, (const char *)(&x + 1));
21666 }
21667 template<typename T>
21668 Section &prepend_contents(const T &x) {
21669 return prepend_contents((const char *)&x, (const char *)(&x + 1));
21670 }
21671 const std::vector<char> &get_contents() const {
21672 return contents;
21673 }
    /** Iterate over the contents as bytes. */
    ///@{
    contents_iterator contents_begin() {
        return contents.begin();
    }
    contents_iterator contents_end() {
        return contents.end();
    }
    const_contents_iterator contents_begin() const {
        return contents.begin();
    }
    const_contents_iterator contents_end() const {
        return contents.end();
    }
    ///@}
    /** Raw pointer/size/emptiness of the stored contents. Note that
     * contents_size() may be smaller than get_size(). */
    ///@{
    const char *contents_data() const {
        return contents.data();
    }
    size_t contents_size() const {
        return contents.size();
    }
    bool contents_empty() const {
        return contents.empty();
    }
    ///@}
21695
    /** Replace the full list of relocations attached to this section. */
    ///@{
    Section &set_relocations(std::vector<Relocation> relocs) {
        this->relocs = std::move(relocs);
        return *this;
    }
    template<typename It>
    Section &set_relocations(It begin, It end) {
        this->relocs.assign(begin, end);
        return *this;
    }
    ///@}
    /** Append a single relocation to this section. */
    void add_relocation(const Relocation &reloc) {
        relocs.push_back(reloc);
    }
    /** Iterate over the relocations attached to this section. */
    ///@{
    relocation_iterator relocations_begin() {
        return relocs.begin();
    }
    relocation_iterator relocations_end() {
        return relocs.end();
    }
    iterator_range<relocation_iterator> relocations() {
        return {relocs.begin(), relocs.end()};
    }
    const_relocation_iterator relocations_begin() const {
        return relocs.begin();
    }
    const_relocation_iterator relocations_end() const {
        return relocs.end();
    }
    iterator_range<const_relocation_iterator> relocations() const {
        return {relocs.begin(), relocs.end()};
    }
    ///@}
    /** The number of relocations attached to this section. */
    size_t relocations_size() const {
        return relocs.size();
    }
21729};
21730
/** Base class for a target architecture to implement the target
 * specific aspects of linking. */
class Linker {
public:
    virtual ~Linker() = default;

    /** Target-specific header values for the object being produced;
     * presumably these populate the corresponding fields set via
     * Object::set_machine / set_flags / set_version — confirm against
     * write_shared_object. */
    ///@{
    virtual uint16_t get_machine() = 0;
    virtual uint32_t get_flags() = 0;
    virtual uint32_t get_version() = 0;
    ///@}

    /** Append any target-specific entries to the given dynamic
     * section. */
    virtual void append_dynamic(Section &dynamic) = 0;

    /** Add or get an entry to the global offset table (GOT) with a
     * relocation pointing to sym. Returns the offset of the entry
     * within the GOT section. */
    virtual uint64_t get_got_entry(Section &got, const Symbol &sym) = 0;

    /** Check to see if this relocation should go through the PLT. */
    virtual bool needs_plt_entry(const Relocation &reloc) = 0;

    /** Add a PLT entry for a symbol sym defined externally. Returns a
     * symbol representing the PLT entry. */
    virtual Symbol add_plt_entry(const Symbol &sym, Section &plt, Section &got,
                                 const Symbol &got_sym) = 0;

    /** Perform a relocation. This function may opt to not apply the
     * relocation, and return a new relocation to be performed at
     * runtime. This requires that the section to apply the relocation
     * to is writable at runtime. */
    virtual Relocation relocate(uint64_t fixup_offset, char *fixup_addr, uint64_t type,
                                const Symbol *sym, uint64_t sym_offset, int64_t addend,
                                Section &got) = 0;
};
21762
21763/** Holds all of the relevant sections and symbols for an object. */
21764class Object {
21765public:
21766 enum Type : uint16_t {
21767 ET_NONE = 0,
21768 ET_REL = 1,
21769 ET_EXEC = 2,
21770 ET_DYN = 3,
21771 ET_CORE = 4,
21772 ET_LOPROC = 0xff00,
21773 ET_HIPROC = 0xffff,
21774 };
21775
21776 // We use lists for sections and symbols to avoid iterator
21777 // invalidation when we modify the containers.
21778 typedef std::list<Section> SectionList;
21779 typedef typename SectionList::iterator section_iterator;
21780 typedef typename SectionList::const_iterator const_section_iterator;
21781
21782 typedef std::list<Symbol> SymbolList;
21783 typedef typename SymbolList::iterator symbol_iterator;
21784 typedef typename SymbolList::const_iterator const_symbol_iterator;
21785
21786private:
21787 SectionList secs;
21788 SymbolList syms;
21789
21790 Type type = ET_NONE;
21791 uint16_t machine = 0;
21792 uint32_t version = 0;
21793 uint64_t entry = 0;
21794 uint32_t flags = 0;
21795
21796 Object(const Object &);
21797 void operator=(const Object &);
21798
21799public:
21800 Object() = default;
21801
21802 Type get_type() const {
21803 return type;
21804 }
21805 uint16_t get_machine() const {
21806 return machine;
21807 }
21808 uint32_t get_version() const {
21809 return version;
21810 }
21811 uint64_t get_entry() const {
21812 return entry;
21813 }
21814 uint32_t get_flags() const {
21815 return flags;
21816 }
21817
21818 Object &set_type(Type type) {
21819 this->type = type;
21820 return *this;
21821 }
21822 Object &set_machine(uint16_t machine) {
21823 this->machine = machine;
21824 return *this;
21825 }
21826 Object &set_version(uint32_t version) {
21827 this->version = version;
21828 return *this;
21829 }
21830 Object &set_entry(uint64_t entry) {
21831 this->entry = entry;
21832 return *this;
21833 }
21834 Object &set_flags(uint32_t flags) {
21835 this->flags = flags;
21836 return *this;
21837 }
21838
21839 /** Parse an object in memory to an Object. */
21840 static std::unique_ptr<Object> parse_object(const char *data, size_t size);
21841
21842 /** Write a shared object in memory. */
21843 std::vector<char> write_shared_object(Linker *linker, const std::vector<std::string> &depedencies = {},
21844 const std::string &soname = "");
21845
21846 section_iterator sections_begin() {
21847 return secs.begin();
21848 }
21849 section_iterator sections_end() {
21850 return secs.end();
21851 }
21852 iterator_range<section_iterator> sections() {
21853 return {secs.begin(), secs.end()};
21854 }
21855 const_section_iterator sections_begin() const {
21856 return secs.begin();
21857 }
21858 const_section_iterator sections_end() const {
21859 return secs.end();
21860 }
21861 iterator_range<const_section_iterator> sections() const {
21862 return {secs.begin(), secs.end()};
21863 }
21864 size_t sections_size() const {
21865 return secs.size();
21866 }
21867 section_iterator find_section(const std::string &name);
21868
21869 section_iterator add_section(const std::string &name, Section::Type type);
21870 section_iterator add_relocation_section(const Section &for_section);
21871 section_iterator erase_section(section_iterator i) {
21872 return secs.erase(i);
21873 }
21874
21875 section_iterator merge_sections(const std::vector<section_iterator> &sections);
21876 section_iterator merge_text_sections();
21877
21878 symbol_iterator symbols_begin() {
21879 return syms.begin();
21880 }
21881 symbol_iterator symbols_end() {
21882 return syms.end();
21883 }
21884 iterator_range<symbol_iterator> symbols() {
21885 return {syms.begin(), syms.end()};
21886 }
21887 const_symbol_iterator symbols_begin() const {
21888 return syms.begin();
21889 }
21890 const_symbol_iterator symbols_end() const {
21891 return syms.end();
21892 }
21893 iterator_range<const_symbol_iterator> symbols() const {
21894 return {syms.begin(), syms.end()};
21895 }
21896 size_t symbols_size() const {
21897 return syms.size();
21898 }
21899 symbol_iterator find_symbol(const std::string &name);
21900 const_symbol_iterator find_symbol(const std::string &name) const;
21901
21902 symbol_iterator add_symbol(const std::string &name);
21903};
21904
21905} // namespace Elf
21906} // namespace Internal
21907} // namespace Halide
21908
21909#endif
21910#ifndef HALIDE_IR_ELIMINATE_BOOL_VECTORS_H
21911#define HALIDE_IR_ELIMINATE_BOOL_VECTORS_H
21912
21913/** \file
21914 * Method to eliminate vectors of booleans from IR.
21915 */
21916
21917
21918namespace Halide {
21919namespace Internal {
21920
21921/** Some targets treat vectors of bools as integers of the same type that the
21922 * boolean operation is being used to operate on. For example, instead of
21923 * select(i1x8, u16x8, u16x8), the target would prefer to see select(u16x8,
21924 * u16x8, u16x8), where the first argument is a vector of integers representing
21925 * a mask. This pass converts vectors of bools to vectors of integers to meet
21926 * this requirement. This is done by injecting intrinsics to convert bools to
21927 * architecture-specific masks, and using a select_mask intrinsic instead of a
21928 * Select node. This also converts any intrinsics that operate on vectorized
21929 * conditions to a *_mask equivalent (if_then_else, require). Because the masks
21930 * are architecture specific, they may not be stored or loaded. On Stores, the
21931 * masks are converted to UInt(8) with a value of 0 or 1, which is our canonical
21932 * in-memory representation of a bool. */
21933///@{
21934Stmt eliminate_bool_vectors(const Stmt &s);
21935Expr eliminate_bool_vectors(const Expr &s);
21936///@}
21937
21938/** If a type is a boolean vector, find the type that it has been
21939 * changed to by eliminate_bool_vectors. */
21940inline Type eliminated_bool_type(Type bool_type, Type other_type) {
21941 if (bool_type.is_vector() && bool_type.bits() == 1) {
21942 bool_type = bool_type.with_code(Type::Int).with_bits(other_type.bits());
21943 }
21944 return bool_type;
21945}
21946
21947} // namespace Internal
21948} // namespace Halide
21949
21950#endif
21951#ifndef HALIDE_EMULATE_FLOAT16_MATH_H
21952#define HALIDE_EMULATE_FLOAT16_MATH_H
21953
21954/** \file
21955 * Methods for dealing with float16 arithmetic using float32 math, by
21956 * casting back and forth with bit tricks.
21957 */
21958
21959
21960namespace Halide {
21961namespace Internal {
21962
21963/** Check if a call is a float16 transcendental (e.g. sqrt_f16) */
21964bool is_float16_transcendental(const Call *);
21965
21966/** Implement a float16 transcendental using the float32 equivalent. */
21967Expr lower_float16_transcendental_to_float32_equivalent(const Call *);
21968
21969/** Cast to/from float and bfloat using bitwise math. */
21970//@{
21971Expr float32_to_bfloat16(Expr e);
21972Expr float32_to_float16(Expr e);
21973Expr float16_to_float32(Expr e);
21974Expr bfloat16_to_float32(Expr e);
21975Expr lower_float16_cast(const Cast *op);
21976//@}
21977
21978} // namespace Internal
21979} // namespace Halide
21980
21981#endif
21982#ifndef HALIDE_EXPR_USES_VAR_H
21983#define HALIDE_EXPR_USES_VAR_H
21984
21985/** \file
21986 * Defines a method to determine if an expression depends on some variables.
21987 */
21988
21989
21990namespace Halide {
21991namespace Internal {
21992
21993template<typename T = void>
21994class ExprUsesVars : public IRGraphVisitor {
21995 using IRGraphVisitor::visit;
21996
21997 const Scope<T> &vars;
21998 Scope<Expr> scope;
21999
22000 void include(const Expr &e) override {
22001 if (result) {
22002 return;
22003 }
22004 IRGraphVisitor::include(e);
22005 }
22006
22007 void include(const Stmt &s) override {
22008 if (result) {
22009 return;
22010 }
22011 IRGraphVisitor::include(s);
22012 }
22013
22014 void visit_name(const std::string &name) {
22015 if (vars.contains(name)) {
22016 result = true;
22017 } else if (scope.contains(name)) {
22018 include(scope.get(name));
22019 }
22020 }
22021
22022 void visit(const Variable *op) override {
22023 visit_name(op->name);
22024 }
22025
22026 void visit(const Load *op) override {
22027 visit_name(op->name);
22028 IRGraphVisitor::visit(op);
22029 }
22030
22031 void visit(const Store *op) override {
22032 visit_name(op->name);
22033 IRGraphVisitor::visit(op);
22034 }
22035
22036 void visit(const Call *op) override {
22037 visit_name(op->name);
22038 IRGraphVisitor::visit(op);
22039 }
22040
22041 void visit(const Provide *op) override {
22042 visit_name(op->name);
22043 IRGraphVisitor::visit(op);
22044 }
22045
22046 void visit(const LetStmt *op) override {
22047 visit_name(op->name);
22048 IRGraphVisitor::visit(op);
22049 }
22050
22051 void visit(const Let *op) override {
22052 visit_name(op->name);
22053 IRGraphVisitor::visit(op);
22054 }
22055
22056 void visit(const Realize *op) override {
22057 visit_name(op->name);
22058 IRGraphVisitor::visit(op);
22059 }
22060
22061 void visit(const Allocate *op) override {
22062 visit_name(op->name);
22063 IRGraphVisitor::visit(op);
22064 }
22065
22066public:
22067 ExprUsesVars(const Scope<T> &v, const Scope<Expr> *s = nullptr)
22068 : vars(v), result(false) {
22069 scope.set_containing_scope(s);
22070 }
22071 bool result;
22072};
22073
22074/** Test if a statement or expression references or defines any of the
22075 * variables in a scope, additionally considering variables bound to
22076 * Expr's in the scope provided in the final argument.
22077 */
22078template<typename StmtOrExpr, typename T>
22079inline bool stmt_or_expr_uses_vars(const StmtOrExpr &e, const Scope<T> &v,
22080 const Scope<Expr> &s = Scope<Expr>::empty_scope()) {
22081 ExprUsesVars<T> uses(v, &s);
22082 e.accept(&uses);
22083 return uses.result;
22084}
22085
22086/** Test if a statement or expression references or defines the given
22087 * variable, additionally considering variables bound to Expr's in the
22088 * scope provided in the final argument.
22089 */
22090template<typename StmtOrExpr>
22091inline bool stmt_or_expr_uses_var(const StmtOrExpr &e, const std::string &v,
22092 const Scope<Expr> &s = Scope<Expr>::empty_scope()) {
22093 Scope<> vars;
22094 vars.push(v);
22095 return stmt_or_expr_uses_vars<StmtOrExpr, void>(e, vars, s);
22096}
22097
/** Test if an expression references or defines the given variable,
 * additionally considering variables bound to Expr's in the scope
 * provided in the final argument. Convenience wrapper around
 * stmt_or_expr_uses_var. */
inline bool expr_uses_var(const Expr &e, const std::string &v,
                          const Scope<Expr> &s = Scope<Expr>::empty_scope()) {
    return stmt_or_expr_uses_var(e, v, s);
}

/** Test if a statement references or defines the given variable,
 * additionally considering variables bound to Expr's in the scope
 * provided in the final argument. Convenience wrapper around
 * stmt_or_expr_uses_var. */
inline bool stmt_uses_var(const Stmt &stmt, const std::string &v,
                          const Scope<Expr> &s = Scope<Expr>::empty_scope()) {
    return stmt_or_expr_uses_var(stmt, v, s);
}

/** Test if an expression references or defines any of the variables
 * in a scope, additionally considering variables bound to Expr's in
 * the scope provided in the final argument. Convenience wrapper
 * around stmt_or_expr_uses_vars. */
template<typename T>
inline bool expr_uses_vars(const Expr &e, const Scope<T> &v,
                           const Scope<Expr> &s = Scope<Expr>::empty_scope()) {
    return stmt_or_expr_uses_vars(e, v, s);
}

/** Test if a statement references or defines any of the variables in
 * a scope, additionally considering variables bound to Expr's in the
 * scope provided in the final argument. Convenience wrapper around
 * stmt_or_expr_uses_vars. */
template<typename T>
inline bool stmt_uses_vars(const Stmt &stmt, const Scope<T> &v,
                           const Scope<Expr> &s = Scope<Expr>::empty_scope()) {
    return stmt_or_expr_uses_vars(stmt, v, s);
}
22135
22136} // namespace Internal
22137} // namespace Halide
22138
22139#endif
22140#ifndef HALIDE_EXTERN_H
22141#define HALIDE_EXTERN_H
22142
22143/** \file
22144 *
22145 * Convenience macros that lift functions that take C types into
22146 * functions that take and return exprs, and call the original
22147 * function at runtime under the hood. See test/c_function.cpp for
22148 * example usage.
22149 */
22150
22151
// Helper used by the HalideExtern_* macros below: asserts (with a user
// error at pipeline-construction time) that Expr argument e has exactly
// the Halide type t expected for argument n of extern function 'name'.
#define _halide_check_arg_type(t, name, e, n) \
    _halide_user_assert(e.type() == t) << "Type mismatch for argument " << n << " to extern function " << #name << ". Type expected is " << t << " but the argument " << e << " has type " << e.type() << ".\n";

// HalideExtern_N(rt, name, t1, ..., tN) defines a function 'name' taking
// N Exprs and returning an Expr: each argument is type-checked against
// the corresponding tK, and a Call node of type Call::Extern with return
// type rt is emitted, calling the extern function 'name' at runtime.
#define HalideExtern_1(rt, name, t1) \
    Halide::Expr name(const Halide::Expr &a1) { \
        _halide_check_arg_type(Halide::type_of<t1>(), name, a1, 1); \
        return Halide::Internal::Call::make(Halide::type_of<rt>(), #name, {a1}, Halide::Internal::Call::Extern); \
    }

#define HalideExtern_2(rt, name, t1, t2) \
    Halide::Expr name(const Halide::Expr &a1, const Halide::Expr &a2) { \
        _halide_check_arg_type(Halide::type_of<t1>(), name, a1, 1); \
        _halide_check_arg_type(Halide::type_of<t2>(), name, a2, 2); \
        return Halide::Internal::Call::make(Halide::type_of<rt>(), #name, {a1, a2}, Halide::Internal::Call::Extern); \
    }

#define HalideExtern_3(rt, name, t1, t2, t3) \
    Halide::Expr name(const Halide::Expr &a1, const Halide::Expr &a2, const Halide::Expr &a3) { \
        _halide_check_arg_type(Halide::type_of<t1>(), name, a1, 1); \
        _halide_check_arg_type(Halide::type_of<t2>(), name, a2, 2); \
        _halide_check_arg_type(Halide::type_of<t3>(), name, a3, 3); \
        return Halide::Internal::Call::make(Halide::type_of<rt>(), #name, {a1, a2, a3}, Halide::Internal::Call::Extern); \
    }

#define HalideExtern_4(rt, name, t1, t2, t3, t4) \
    Halide::Expr name(const Halide::Expr &a1, const Halide::Expr &a2, const Halide::Expr &a3, const Halide::Expr &a4) { \
        _halide_check_arg_type(Halide::type_of<t1>(), name, a1, 1); \
        _halide_check_arg_type(Halide::type_of<t2>(), name, a2, 2); \
        _halide_check_arg_type(Halide::type_of<t3>(), name, a3, 3); \
        _halide_check_arg_type(Halide::type_of<t4>(), name, a4, 4); \
        return Halide::Internal::Call::make(Halide::type_of<rt>(), #name, {a1, a2, a3, a4}, Halide::Internal::Call::Extern); \
    }

#define HalideExtern_5(rt, name, t1, t2, t3, t4, t5) \
    Halide::Expr name(const Halide::Expr &a1, const Halide::Expr &a2, const Halide::Expr &a3, const Halide::Expr &a4, const Halide::Expr &a5) { \
        _halide_check_arg_type(Halide::type_of<t1>(), name, a1, 1); \
        _halide_check_arg_type(Halide::type_of<t2>(), name, a2, 2); \
        _halide_check_arg_type(Halide::type_of<t3>(), name, a3, 3); \
        _halide_check_arg_type(Halide::type_of<t4>(), name, a4, 4); \
        _halide_check_arg_type(Halide::type_of<t5>(), name, a5, 5); \
        return Halide::Internal::Call::make(Halide::type_of<rt>(), #name, {a1, a2, a3, a4, a5}, Halide::Internal::Call::Extern); \
    }

// The HalidePureExtern_N variants are identical except that the emitted
// Call node is Call::PureExtern, marking the extern function as pure.
#define HalidePureExtern_1(rt, name, t1) \
    Halide::Expr name(const Halide::Expr &a1) { \
        _halide_check_arg_type(Halide::type_of<t1>(), name, a1, 1); \
        return Halide::Internal::Call::make(Halide::type_of<rt>(), #name, {a1}, Halide::Internal::Call::PureExtern); \
    }

#define HalidePureExtern_2(rt, name, t1, t2) \
    Halide::Expr name(const Halide::Expr &a1, const Halide::Expr &a2) { \
        _halide_check_arg_type(Halide::type_of<t1>(), name, a1, 1); \
        _halide_check_arg_type(Halide::type_of<t2>(), name, a2, 2); \
        return Halide::Internal::Call::make(Halide::type_of<rt>(), #name, {a1, a2}, Halide::Internal::Call::PureExtern); \
    }

#define HalidePureExtern_3(rt, name, t1, t2, t3) \
    Halide::Expr name(const Halide::Expr &a1, const Halide::Expr &a2, const Halide::Expr &a3) { \
        _halide_check_arg_type(Halide::type_of<t1>(), name, a1, 1); \
        _halide_check_arg_type(Halide::type_of<t2>(), name, a2, 2); \
        _halide_check_arg_type(Halide::type_of<t3>(), name, a3, 3); \
        return Halide::Internal::Call::make(Halide::type_of<rt>(), #name, {a1, a2, a3}, Halide::Internal::Call::PureExtern); \
    }

#define HalidePureExtern_4(rt, name, t1, t2, t3, t4) \
    Halide::Expr name(const Halide::Expr &a1, const Halide::Expr &a2, const Halide::Expr &a3, const Halide::Expr &a4) { \
        _halide_check_arg_type(Halide::type_of<t1>(), name, a1, 1); \
        _halide_check_arg_type(Halide::type_of<t2>(), name, a2, 2); \
        _halide_check_arg_type(Halide::type_of<t3>(), name, a3, 3); \
        _halide_check_arg_type(Halide::type_of<t4>(), name, a4, 4); \
        return Halide::Internal::Call::make(Halide::type_of<rt>(), #name, {a1, a2, a3, a4}, Halide::Internal::Call::PureExtern); \
    }

#define HalidePureExtern_5(rt, name, t1, t2, t3, t4, t5) \
    Halide::Expr name(const Halide::Expr &a1, const Halide::Expr &a2, const Halide::Expr &a3, const Halide::Expr &a4, const Halide::Expr &a5) { \
        _halide_check_arg_type(Halide::type_of<t1>(), name, a1, 1); \
        _halide_check_arg_type(Halide::type_of<t2>(), name, a2, 2); \
        _halide_check_arg_type(Halide::type_of<t3>(), name, a3, 3); \
        _halide_check_arg_type(Halide::type_of<t4>(), name, a4, 4); \
        _halide_check_arg_type(Halide::type_of<t5>(), name, a5, 5); \
        return Halide::Internal::Call::make(Halide::type_of<rt>(), #name, {a1, a2, a3, a4, a5}, Halide::Internal::Call::PureExtern); \
    }
22234#endif
22235#ifndef HALIDE_FAST_INTEGER_DIVIDE_H
22236#define HALIDE_FAST_INTEGER_DIVIDE_H
22237
22238
22239namespace Halide {
22240
22241/** Integer division by small values can be done exactly as multiplies
22242 * and shifts. This function does integer division for numerators of
22243 * various integer types (8, 16, 32 bit signed and unsigned)
22244 * numerators and uint8 denominators. The type of the result is the
22245 * type of the numerator. The unsigned version is faster than the
22246 * signed version, so cast the numerator to an unsigned int if you
22247 * know it's positive.
22248 *
22249 * If your divisor is compile-time constant, Halide performs a
22250 * slightly better optimization automatically, so there's no need to
22251 * use this function (but it won't hurt).
22252 *
22253 * This function vectorizes well on arm, and well on x86 for 16 and 8
22254 * bit vectors. For 32-bit vectors on x86 you're better off using
22255 * native integer division.
22256 *
22257 * Also, this routine treats division by zero as division by
22258 * 256. I.e. it interprets the uint8 divisor as a number from 1 to 256
22259 * inclusive.
22260 */
22261Expr fast_integer_divide(Expr numerator, Expr denominator);
22262
22263/** Use the fast integer division tables to implement a modulo
22264 * operation via the Euclidean identity: a%b = a - (a/b)*b
22265 */
22266Expr fast_integer_modulo(Expr numerator, Expr denominator);
22267
22268} // namespace Halide
22269
22270#endif
22271#ifndef FIND_CALLS_H
22272#define FIND_CALLS_H
22273
22274/** \file
22275 *
 * Defines analyses to extract the functions called by a function.
22277 */
22278
22279#include <map>
22280#include <string>
22281
22282
22283namespace Halide {
22284namespace Internal {
22285
22286class Function;
22287
22288/** Construct a map from name to Function definition object for all Halide
22289 * functions called directly in the definition of the Function f, including
22290 * in update definitions, update index expressions, and RDom extents. This map
22291 * _does not_ include the Function f, unless it is called recursively by
22292 * itself.
22293 */
22294std::map<std::string, Function> find_direct_calls(Function f);
22295
22296/** Construct a map from name to Function definition object for all Halide
22297 * functions called directly in the definition of the Function f, or
22298 * indirectly in those functions' definitions, recursively. This map always
22299 * _includes_ the Function f.
22300 */
22301std::map<std::string, Function> find_transitive_calls(Function f);
22302
22303/** Find all Functions transitively referenced by f in any way and add
22304 * them to the given map. */
22305void populate_environment(Function f, std::map<std::string, Function> &env);
22306
22307} // namespace Internal
22308} // namespace Halide
22309
22310#endif
22311#ifndef HALIDE_FIND_INTRINSICS_H
22312#define HALIDE_FIND_INTRINSICS_H
22313
22314/** \file
22315 * Tools to replace common patterns with more readily recognizable intrinsics.
22316 */
22317
22318
22319namespace Halide {
22320namespace Internal {
22321
22322/** Implement intrinsics with non-intrinsic using equivalents. */
22323Expr lower_widening_add(const Expr &a, const Expr &b);
22324Expr lower_widening_mul(const Expr &a, const Expr &b);
22325Expr lower_widening_sub(const Expr &a, const Expr &b);
22326Expr lower_widening_shift_left(const Expr &a, const Expr &b);
22327Expr lower_widening_shift_right(const Expr &a, const Expr &b);
22328
22329Expr lower_rounding_shift_left(const Expr &a, const Expr &b);
22330Expr lower_rounding_shift_right(const Expr &a, const Expr &b);
22331
22332Expr lower_saturating_add(const Expr &a, const Expr &b);
22333Expr lower_saturating_sub(const Expr &a, const Expr &b);
22334
22335Expr lower_halving_add(const Expr &a, const Expr &b);
22336Expr lower_halving_sub(const Expr &a, const Expr &b);
22337Expr lower_rounding_halving_add(const Expr &a, const Expr &b);
22338Expr lower_rounding_halving_sub(const Expr &a, const Expr &b);
22339
22340Expr lower_mul_shift_right(const Expr &a, const Expr &b, const Expr &q);
22341Expr lower_rounding_mul_shift_right(const Expr &a, const Expr &b, const Expr &q);
22342
22343/** Replace one of the above ops with equivalent arithmetic. */
22344Expr lower_intrinsic(const Call *op);
22345
22346/** Replace common arithmetic patterns with intrinsics. */
22347Stmt find_intrinsics(const Stmt &s);
22348Expr find_intrinsics(const Expr &e);
22349
22350/** The reverse of find_intrinsics. */
22351Expr lower_intrinsics(const Expr &e);
22352Stmt lower_intrinsics(const Stmt &s);
22353
22354} // namespace Internal
22355} // namespace Halide
22356
22357#endif
22358#ifndef HALIDE_FLATTEN_NESTED_RAMPS_H
22359#define HALIDE_FLATTEN_NESTED_RAMPS_H
22360
22361/** \file
22362 * Defines the lowering pass that flattens nested ramps and broadcasts.
22363 * */
22364
22365
22366namespace Halide {
22367namespace Internal {
22368
22369/** Take a statement/expression and replace nested ramps and broadcasts. */
22370Stmt flatten_nested_ramps(const Stmt &s);
22371Expr flatten_nested_ramps(const Expr &e);
22372
22373} // namespace Internal
22374} // namespace Halide
22375
22376#endif
22377#ifndef HALIDE_FUSE_GPU_THREAD_LOOPS_H
22378#define HALIDE_FUSE_GPU_THREAD_LOOPS_H
22379
22380/** \file
22381 * Defines the lowering pass that fuses and normalizes loops over gpu
22382 * threads to target CUDA, OpenCL, and Metal.
22383 */
22384
22385
22386namespace Halide {
22387namespace Internal {
22388
22389/** Rewrite all GPU loops to have a min of zero. */
22390Stmt zero_gpu_loop_mins(const Stmt &s);
22391
22392/** Converts Halide's GPGPU IR to the OpenCL/CUDA/Metal model. Within
22393 * every loop over gpu block indices, fuse the inner loops over thread
22394 * indices into a single loop (with predication to turn off
22395 * threads). Push if conditions between GPU blocks to the innermost GPU threads.
22396 * Also injects synchronization points as needed, and hoists
22397 * shared allocations at the block level out into a single shared
22398 * memory array, and heap allocations into a slice of a global pool
22399 * allocated outside the kernel. */
22400Stmt fuse_gpu_thread_loops(Stmt s);
22401
22402} // namespace Internal
22403} // namespace Halide
22404
22405#endif
22406#ifndef FUZZ_FLOAT_STORES_H
22407#define FUZZ_FLOAT_STORES_H
22408
22409
22410/** \file
22411 * Defines a lowering pass that messes with floating point stores.
22412 */
22413
22414namespace Halide {
22415namespace Internal {
22416
22417/** On every store of a floating point value, mask off the
22418 * least-significant-bit of the mantissa. We've found that whether or
22419 * not this dramatically changes the output of a pipeline correlates
22420 * very well with whether or not a pipeline will produce very
22421 * different outputs on different architectures (e.g. with and without
22422 * FMA). It's also a useful way to detect bad tests, such as those
22423 * that expect exact floating point equality across platforms. */
22424Stmt fuzz_float_stores(const Stmt &s);
22425
22426} // namespace Internal
22427} // namespace Halide
22428
22429#endif
22430#ifndef HALIDE_GENERATOR_H_
22431#define HALIDE_GENERATOR_H_
22432
22433/** \file
22434 *
22435 * Generator is a class used to encapsulate the building of Funcs in user
22436 * pipelines. A Generator is agnostic to JIT vs AOT compilation; it can be used for
22437 * either purpose, but is especially convenient to use for AOT compilation.
22438 *
22439 * A Generator explicitly declares the Inputs and Outputs associated for a given
22440 * pipeline, and (optionally) separates the code for constructing the outputs from the code from
22441 * scheduling them. For instance:
22442 *
22443 * \code
22444 * class Blur : public Generator<Blur> {
22445 * public:
22446 * Input<Func> input{"input", UInt(16), 2};
22447 * Output<Func> output{"output", UInt(16), 2};
22448 * void generate() {
22449 * blur_x(x, y) = (input(x, y) + input(x+1, y) + input(x+2, y))/3;
22450 * blur_y(x, y) = (blur_x(x, y) + blur_x(x, y+1) + blur_x(x, y+2))/3;
 *         output(x, y) = blur_y(x, y);
22452 * }
22453 * void schedule() {
22454 * blur_y.split(y, y, yi, 8).parallel(y).vectorize(x, 8);
22455 * blur_x.store_at(blur_y, y).compute_at(blur_y, yi).vectorize(x, 8);
22456 * }
22457 * private:
22458 * Var x, y, xi, yi;
22459 * Func blur_x, blur_y;
22460 * };
22461 * \endcode
22462 *
22463 * Halide can compile a Generator into the correct pipeline by introspecting these
22464 * values and constructing an appropriate signature based on them.
22465 *
22466 * A Generator provides implementations of two methods:
22467 *
22468 * - generate(), which must fill in all Output Func(s); it may optionally also do scheduling
22469 * if no schedule() method is present.
22470 * - schedule(), which (if present) should contain all scheduling code.
22471 *
22472 * Inputs can be any C++ scalar type:
22473 *
22474 * \code
22475 * Input<float> radius{"radius"};
22476 * Input<int32_t> increment{"increment"};
22477 * \endcode
22478 *
 * An Input<Func> is (essentially) like an ImageParam, except that it may (or may
 * not) be backed by an actual buffer, and thus has no defined extents.
22481 *
22482 * \code
22483 * Input<Func> input{"input", Float(32), 2};
22484 * \endcode
22485 *
22486 * You can optionally make the type and/or dimensions of Input<Func> unspecified,
22487 * in which case the value is simply inferred from the actual Funcs passed to them.
22488 * Of course, if you specify an explicit Type or Dimension, we still require the
22489 * input Func to match, or a compilation error results.
22490 *
22491 * \code
22492 * Input<Func> input{ "input", 3 }; // require 3-dimensional Func,
22493 * // but leave Type unspecified
22494 * \endcode
22495 *
22496 * A Generator must explicitly list the output(s) it produces:
22497 *
22498 * \code
22499 * Output<Func> output{"output", Float(32), 2};
22500 * \endcode
22501 *
22502 * You can specify an output that returns a Tuple by specifying a list of Types:
22503 *
22504 * \code
22505 * class Tupler : Generator<Tupler> {
22506 * Input<Func> input{"input", Int(32), 2};
22507 * Output<Func> output{"output", {Float(32), UInt(8)}, 2};
22508 * void generate() {
22509 * Var x, y;
22510 * Expr a = cast<float>(input(x, y));
22511 * Expr b = cast<uint8_t>(input(x, y));
22512 * output(x, y) = Tuple(a, b);
22513 * }
22514 * };
22515 * \endcode
22516 *
22517 * You can also specify Output<X> for any scalar type (except for Handle types);
22518 * this is merely syntactic sugar on top of a zero-dimensional Func, but can be
22519 * quite handy, especially when used with multiple outputs:
22520 *
22521 * \code
22522 * Output<float> sum{"sum"}; // equivalent to Output<Func> {"sum", Float(32), 0}
22523 * \endcode
22524 *
22525 * As with Input<Func>, you can optionally make the type and/or dimensions of an
22526 * Output<Func> unspecified; any unspecified types must be resolved via an
22527 * implicit GeneratorParam in order to use top-level compilation.
22528 *
22529 * You can also declare an *array* of Input or Output, by using an array type
22530 * as the type parameter:
22531 *
22532 * \code
22533 * // Takes exactly 3 images and outputs exactly 3 sums.
22534 * class SumRowsAndColumns : Generator<SumRowsAndColumns> {
22535 * Input<Func[3]> inputs{"inputs", Float(32), 2};
22536 * Input<int32_t[2]> extents{"extents"};
22537 * Output<Func[3]> sums{"sums", Float(32), 1};
22538 * void generate() {
22539 * assert(inputs.size() == sums.size());
22540 * // assume all inputs are same extent
 *         Expr width = extents[0];
 *         Expr height = extents[1];
22543 * for (size_t i = 0; i < inputs.size(); ++i) {
22544 * RDom r(0, width, 0, height);
22545 * sums[i]() = 0.f;
22546 * sums[i]() += inputs[i](r.x, r.y);
22547 * }
22548 * }
22549 * };
22550 * \endcode
22551 *
22552 * You can also leave array size unspecified, with some caveats:
22553 * - For ahead-of-time compilation, Inputs must have a concrete size specified
22554 * via a GeneratorParam at build time (e.g., pyramid.size=3)
22555 * - For JIT compilation via a Stub, Inputs array sizes will be inferred
22556 * from the vector passed.
22557 * - For ahead-of-time compilation, Outputs may specify a concrete size
22558 * via a GeneratorParam at build time (e.g., pyramid.size=3), or the
22559 * size can be specified via a resize() method.
22560 *
22561 * \code
22562 * class Pyramid : public Generator<Pyramid> {
22563 * public:
22564 * GeneratorParam<int32_t> levels{"levels", 10};
22565 * Input<Func> input{ "input", Float(32), 2 };
22566 * Output<Func[]> pyramid{ "pyramid", Float(32), 2 };
22567 * void generate() {
22568 * pyramid.resize(levels);
22569 * pyramid[0](x, y) = input(x, y);
22570 * for (int i = 1; i < pyramid.size(); i++) {
22571 * pyramid[i](x, y) = (pyramid[i-1](2*x, 2*y) +
22572 * pyramid[i-1](2*x+1, 2*y) +
22573 * pyramid[i-1](2*x, 2*y+1) +
22574 * pyramid[i-1](2*x+1, 2*y+1))/4;
22575 * }
22576 * }
22577 * };
22578 * \endcode
22579 *
22580 * A Generator can also be customized via compile-time parameters (GeneratorParams),
22581 * which affect code generation.
22582 *
22583 * GeneratorParams, Inputs, and Outputs are (by convention) always
22584 * public and always declared at the top of the Generator class, in the order
22585 *
22586 * \code
22587 * GeneratorParam(s)
22588 * Input<Func>(s)
22589 * Input<non-Func>(s)
22590 * Output<Func>(s)
22591 * \endcode
22592 *
22593 * Note that the Inputs and Outputs will appear in the C function call in the order
22594 * they are declared. All Input<Func> and Output<Func> are represented as halide_buffer_t;
22595 * all other Input<> are the appropriate C++ scalar type. (GeneratorParams are
22596 * always referenced by name, not position, so their order is irrelevant.)
22597 *
22598 * All Inputs and Outputs must have explicit names, and all such names must match
22599 * the regex [A-Za-z][A-Za-z_0-9]* (i.e., essentially a C/C++ variable name, with
22600 * some extra restrictions on underscore use). By convention, the name should match
22601 * the member-variable name.
22602 *
22603 * You can dynamically add Inputs and Outputs to your Generator via adding a
22604 * configure() method; if present, it will be called before generate(). It can
22605 * examine GeneratorParams but it may not examine predeclared Inputs or Outputs;
22606 * the only thing it should do is call add_input<>() and/or add_output<>(), or call
22607 * set_type()/set_dimensions()/set_array_size() on an Input or Output with an unspecified type.
22608 * Added inputs will be appended (in order) after predeclared Inputs but before
22609 * any Outputs; added outputs will be appended after predeclared Outputs.
22610 *
22611 * Note that the pointers returned by add_input() and add_output() are owned
22612 * by the Generator and will remain valid for the Generator's lifetime; user code
22613 * should not attempt to delete or free them.
22614 *
22615 * \code
22616 * class MultiSum : public Generator<MultiSum> {
22617 * public:
22618 * GeneratorParam<int32_t> input_count{"input_count", 10};
22619 * Output<Func> output{ "output", Float(32), 2 };
22620 *
22621 * void configure() {
22622 * for (int i = 0; i < input_count; ++i) {
 *             extra_inputs.push_back(
 *                 add_input<Func>("input_" + std::to_string(i), Float(32), 2));
22625 * }
22626 * }
22627 *
22628 * void generate() {
22629 * Expr sum = 0.f;
22630 * for (int i = 0; i < input_count; ++i) {
 *             sum += (*extra_inputs[i])(x, y);
22632 * }
22633 * output(x, y) = sum;
22634 * }
22635 * private:
 *     std::vector<Input<Func> *> extra_inputs;
22637 * };
22638 * \endcode
22639 *
22640 * All Generators have three GeneratorParams that are implicitly provided
22641 * by the base class:
22642 *
22643 * GeneratorParam<Target> target{"target", Target()};
22644 * GeneratorParam<bool> auto_schedule{"auto_schedule", false};
22645 * GeneratorParam<MachineParams> machine_params{"machine_params", MachineParams::generic()};
22646 *
22647 * - 'target' is the Halide::Target for which the Generator is producing code.
22648 * It is read-only during the Generator's lifetime, and must not be modified;
22649 * its value should always be filled in by the calling code: either the Halide
22650 * build system (for ahead-of-time compilation), or ordinary C++ code
22651 * (for JIT compilation).
22652 * - 'auto_schedule' indicates whether the auto-scheduler should be run for this
22653 * Generator:
22654 * - if 'false', the Generator should schedule its Funcs as it sees fit.
22655 * - if 'true', the Generator should only provide estimate()s for its Funcs,
22656 * and not call any other scheduling methods.
22657 * - 'machine_params' is only used if auto_schedule is true; it is ignored
22658 * if auto_schedule is false. It provides details about the machine architecture
22659 * being targeted which may be used to enhance the automatically-generated
22660 * schedule.
22661 *
22662 * Generators are added to a global registry to simplify AOT build mechanics; this
22663 * is done by simply using the HALIDE_REGISTER_GENERATOR macro at global scope:
22664 *
22665 * \code
22666 * HALIDE_REGISTER_GENERATOR(ExampleGen, jit_example)
22667 * \endcode
22668 *
 * The registered name of the Generator must match the same rules as
 * Input names, above.
22671 *
22672 * Note that the class name of the generated Stub class will match the registered
22673 * name by default; if you want to vary it (typically, to include namespaces),
22674 * you can add it as an optional third argument:
22675 *
22676 * \code
22677 * HALIDE_REGISTER_GENERATOR(ExampleGen, jit_example, SomeNamespace::JitExampleStub)
22678 * \endcode
22679 *
22680 * Note that a Generator is always executed with a specific Target assigned to it,
22681 * that you can access via the get_target() method. (You should *not* use the
22682 * global get_target_from_environment(), etc. methods provided in Target.h)
22683 *
22684 * (Note that there are older variations of Generator that differ from what's
22685 * documented above; these are still supported but not described here. See
22686 * https://github.com/halide/Halide/wiki/Old-Generator-Documentation for
22687 * more information.)
22688 */
22689
22690#include <algorithm>
22691#include <functional>
22692#include <iterator>
22693#include <limits>
22694#include <memory>
22695#include <mutex>
22696#include <set>
22697#include <sstream>
22698#include <string>
22699#include <type_traits>
22700#include <utility>
22701#include <vector>
22702
22703#ifndef HALIDE_IMAGE_PARAM_H
22704#define HALIDE_IMAGE_PARAM_H
22705
22706/** \file
22707 *
22708 * Classes for declaring image parameters to halide pipelines
22709 */
22710
22711#include <utility>
22712
22713#ifndef HALIDE_OUTPUT_IMAGE_PARAM_H
22714#define HALIDE_OUTPUT_IMAGE_PARAM_H
22715
22716/** \file
22717 *
22718 * Classes for declaring output image parameters to halide pipelines
22719 */
22720
22721
22722namespace Halide {
22723
/** A handle on the output buffer of a pipeline. Used to make static
 * promises about the output size and stride. */
class OutputImageParam {
protected:
    friend class Func;

    /** A reference-counted handle on the internal parameter object */
    Internal::Parameter param;

    /** Is this an input or an output? OutputImageParam is the base class for both.
     * (InputScalar here is just the default; the real kind is supplied via the
     * protected constructor below.) */
    Argument::Kind kind = Argument::InputScalar;

    /** If Input: Func representation of the ImageParam.
     * If Output: Func that creates this OutputImageParam.
     */
    Func func;

    /** Helper for the call operators: extends 'args' with implicit variables
     * when a placeholder argument is present. NOTE(review): semantics inferred
     * from the name and the operator() docs (see \ref Var::implicit); confirm
     * against the out-of-line definition. */
    void add_implicit_args_if_placeholder(std::vector<Expr> &args,
                                          Expr last_arg,
                                          int total_args,
                                          bool *placeholder_seen) const;

    /** Construct an OutputImageParam that wraps an Internal Parameter object. */
    OutputImageParam(const Internal::Parameter &p, Argument::Kind k, Func f);

public:
    /** Construct a null image parameter handle. */
    OutputImageParam() = default;

    /** Get the name of this Param */
    const std::string &name() const;

    /** Get the type of the image data this Param refers to */
    Type type() const;

    /** Is this parameter handle non-nullptr */
    bool defined() const;

    /** Get a handle on one of the dimensions for the purposes of
     * inspecting or constraining its min, extent, or stride. */
    Internal::Dimension dim(int i);

    /** Get a handle on one of the dimensions for the purposes of
     * inspecting its min, extent, or stride. */
    Internal::Dimension dim(int i) const;

    /** Get the alignment of the host pointer in bytes. Defaults to
     * the size of type. */
    int host_alignment() const;

    /** Set the expected alignment of the host pointer in bytes. */
    OutputImageParam &set_host_alignment(int);

    /** Get the dimensionality of this image parameter */
    int dimensions() const;

    /** Get an expression giving the minimum coordinate in dimension 0, which
     * by convention is the coordinate of the left edge of the image */
    Expr left() const;

    /** Get an expression giving the maximum coordinate in dimension 0, which
     * by convention is the coordinate of the right edge of the image */
    Expr right() const;

    /** Get an expression giving the minimum coordinate in dimension 1, which
     * by convention is the top of the image */
    Expr top() const;

    /** Get an expression giving the maximum coordinate in dimension 1, which
     * by convention is the bottom of the image */
    Expr bottom() const;

    /** Get an expression giving the extent in dimension 0, which by
     * convention is the width of the image */
    Expr width() const;

    /** Get an expression giving the extent in dimension 1, which by
     * convention is the height of the image */
    Expr height() const;

    /** Get an expression giving the extent in dimension 2, which by
     * convention is the channel-count of the image */
    Expr channels() const;

    /** Get at the internal parameter object representing this ImageParam. */
    Internal::Parameter parameter() const;

    /** Construct the appropriate argument matching this parameter,
     * for the purpose of generating the right type signature when
     * statically compiling halide pipelines. */
    operator Argument() const;

    /** Using a param as the argument to an external stage treats it
     * as an Expr */
    operator ExternFuncArgument() const;

    /** Set (min, extent) estimates for all dimensions in the ImageParam
     * at once; this is equivalent to calling `dim(n).set_estimate(min, extent)`
     * repeatedly, but slightly terser. The size of the estimates vector
     * must match the dimensionality of the ImageParam. */
    OutputImageParam &set_estimates(const Region &estimates);

    /** Set the desired storage type for this parameter. Only useful
     * for MemoryType::GPUTexture at present */
    OutputImageParam &store_in(MemoryType type);
};
22830
22831} // namespace Halide
22832
22833#endif
22834
22835namespace Halide {
22836
22837namespace Internal {
22838template<typename T2>
22839class GeneratorInput_Buffer;
22840}
22841
/** An Image parameter to a halide pipeline. E.g., the input image. */
class ImageParam : public OutputImageParam {
    template<typename T2>
    friend class ::Halide::Internal::GeneratorInput_Buffer;

    /** Construct from an existing Parameter and Func. Only for use of
     * Generator (via the GeneratorInput_Buffer friend above); always an
     * InputBuffer argument. */
    ImageParam(const Internal::Parameter &p, Func f)
        : OutputImageParam(p, Argument::InputBuffer, std::move(f)) {
    }

    /** Helper function to initialize the Func representation of this ImageParam. */
    Func create_func() const;

public:
    /** Construct a nullptr image parameter handle. */
    ImageParam() = default;

    /** Construct an image parameter of the given type and
     * dimensionality, with an auto-generated unique name. */
    ImageParam(Type t, int d);

    /** Construct an image parameter of the given type and
     * dimensionality, with the given name */
    ImageParam(Type t, int d, const std::string &n);

    /** Bind an Image to this ImageParam. Only relevant for jitting */
    // @{
    void set(const Buffer<> &im);
    // @}

    /** Get a reference to the Buffer bound to this ImageParam. Only relevant for jitting. */
    // @{
    Buffer<> get() const;
    // @}

    /** Unbind any bound Buffer */
    void reset();

    /** Construct an expression which loads from this image
     * parameter. The location is extended with enough implicit
     * variables to match the dimensionality of the image
     * (see \ref Var::implicit)
     */
    // @{
    template<typename... Args>
    HALIDE_NO_USER_CODE_INLINE Expr operator()(Args &&...args) const {
        // Forward the call to the underlying Func representation.
        return func(std::forward<Args>(args)...);
    }
    Expr operator()(std::vector<Expr>) const;
    Expr operator()(std::vector<Var>) const;
    // @}

    /** Return the intrinsic Func representation of this ImageParam. This allows
     * an ImageParam to be implicitly converted to a Func.
     *
     * Note that we use implicit vars to name the dimensions of Funcs associated
     * with the ImageParam: both its internal Func representation and wrappers
     * (See \ref ImageParam::in). For example, to unroll the first and second
     * dimensions of the associated Func by a factor of 2, we would do the following:
     \code
     func.unroll(_0, 2).unroll(_1, 2);
     \endcode
     * '_0' represents the first dimension of the Func, while _1 represents the
     * second dimension of the Func.
     */
    operator Func() const;

    /** Creates and returns a new Func that wraps this ImageParam. During
     * compilation, Halide will replace calls to this ImageParam with calls
     * to the wrapper as appropriate. If this ImageParam is already wrapped
     * for use in some Func, it will return the existing wrapper.
     *
     * For example, img.in(g) would rewrite a pipeline like this:
     \code
     ImageParam img(Int(32), 2);
     Func g;
     g(x, y) = ... img(x, y) ...
     \endcode
     * into a pipeline like this:
     \code
     ImageParam img(Int(32), 2);
     Func img_wrap, g;
     img_wrap(x, y) = img(x, y);
     g(x, y) = ... img_wrap(x, y) ...
     \endcode
     *
     * This has a variety of uses. One use case is to stage loads from an
     * ImageParam via some intermediate buffer (e.g. on the stack or in shared
     * GPU memory).
     *
     * The following example illustrates how you would use the 'in()' directive
     * to stage loads from an ImageParam into the GPU shared memory:
     \code
     ImageParam img(Int(32), 2);
     output(x, y) = img(y, x);
     Var tx, ty;
     output.compute_root().gpu_tile(x, y, tx, ty, 8, 8);
     img.in().compute_at(output, x).unroll(_0, 2).unroll(_1, 2).gpu_threads(_0, _1);
     \endcode
     *
     * Note that we use implicit vars to name the dimensions of the wrapper Func.
     * See \ref Func::in for more possible use cases of the 'in()' directive.
     */
    // @{
    Func in(const Func &f);
    Func in(const std::vector<Func> &fs);
    Func in();
    // @}

    /** Trace all loads from this ImageParam by emitting calls to halide_trace. */
    void trace_loads();

    /** Add a trace tag to this ImageParam's Func. */
    ImageParam &add_trace_tag(const std::string &trace_tag);
};
22957
22958} // namespace Halide
22959
22960#endif
22961#ifndef HALIDE_INTROSPECTION_H
22962#define HALIDE_INTROSPECTION_H
22963
22964#include <cstdint>
22965#include <iostream>
22966#include <string>
22967
22968/** \file
22969 *
22970 * Defines methods for introspecting in C++. Relies on DWARF debugging
22971 * metadata, so the compilation unit that uses this must be compiled
22972 * with -g.
22973 */
22974
22975namespace Halide {
22976namespace Internal {
22977
22978namespace Introspection {
22979/** Get the name of a stack variable from its address. The stack
22980 * variable must be in a compilation unit compiled with -g to
22981 * work. The expected type helps distinguish between variables at the
22982 * same address, e.g a class instance vs its first member. */
22983std::string get_variable_name(const void *, const std::string &expected_type);
22984
22985/** Register an untyped heap object. Derive type information from an
22986 * introspectable pointer to a pointer to a global object of the same
22987 * type. Not thread-safe. */
22988void register_heap_object(const void *obj, size_t size, const void *helper);
22989
22990/** Deregister a heap object. Not thread-safe. */
22991void deregister_heap_object(const void *obj, size_t size);
22992
22993/** Dump the contents of the stack frame of the calling function. Used
22994 * for debugging stack frame sizes inside the compiler. Returns
22995 * whether or not it was able to find the relevant debug
22996 * information. */
22997bool dump_stack_frame();
22998
22999#define HALIDE_DUMP_STACK_FRAME \
23000 { \
23001 static bool check = Halide::Internal::Introspection::dump_stack_frame(); \
23002 (void)check; \
23003 }
23004
23005/** Return the address of a global with type T *. Call this to
23006 * generate something to pass as the last argument to
23007 * register_heap_object.
23008 */
template<typename T>
const void *get_introspection_helper() {
    // One static pointer slot per instantiation of T; its (stable) address
    // uniquely identifies the type, and is what register_heap_object expects.
    static T *type_marker = nullptr;
    return static_cast<const void *>(&type_marker);
}
23014
23015/** Get the source location in the call stack, skipping over calls in
23016 * the Halide namespace. */
23017std::string get_source_location();
23018
23019// This gets called automatically by anyone who includes Halide.h by
23020// the code below. It tests if this functionality works for the given
23021// compilation unit, and disables it if not.
23022void test_compilation_unit(bool (*test)(bool (*)(const void *, const std::string &)),
23023 bool (*test_a)(const void *, const std::string &),
23024 void (*calib)());
23025} // namespace Introspection
23026
23027} // namespace Internal
23028} // namespace Halide
23029
23030// This code verifies that introspection is working before relying on
23031// it. The definitions must appear in Halide.h, but they should not
23032// appear in libHalide itself. They're defined as static so that clients
23033// can include Halide.h multiple times without link errors.
23034#ifndef COMPILING_HALIDE
23035
23036namespace Halide {
23037namespace Internal {
23038static bool check_introspection(const void *var, const std::string &type,
23039 const std::string &correct_name,
23040 const std::string &correct_file, int line) {
23041 std::string correct_loc = correct_file + ":" + std::to_string(line);
23042 std::string loc = Introspection::get_source_location();
23043 std::string name = Introspection::get_variable_name(var, type);
23044 return name == correct_name && loc == correct_loc;
23045}
23046} // namespace Internal
23047} // namespace Halide
23048
23049namespace HalideIntrospectionCanary {
23050
// A function that acts as a signpost. By taking its address and
// comparing it to the program counter listed in the debugging info,
// we can calibrate for any offset between the debugging info and the
// actual memory layout where the code was loaded.
static void offset_marker() {
    std::cerr << "You should not have called this function\n";
}
23058
// A struct with known members and nesting, used as the subject of the
// introspection self-test below: test_a() checks that each member's name
// and type can be recovered from the DWARF debug info.
struct A {
    int an_int;

    class B {
        int private_member = 17;

    public:
        float a_float;
        A *parent;
        B() {
            a_float = private_member * 2.0f;
        }
    };

    B a_b;

    A() {
        // Back-pointer lets test_a() check introspection of the whole object.
        a_b.parent = this;
    }

    bool test(const std::string &my_name);
};
23081
// Exercise introspection on each member of an A instance (and on the whole
// instance, via the parent back-pointer), verifying that the recovered
// names and types match what we expect.
static bool test_a(const void *a_ptr, const std::string &my_name) {
    const A *a = (const A *)a_ptr;
    bool success = true;
    success &= Halide::Internal::check_introspection(&a->an_int, "int", my_name + ".an_int", __FILE__, __LINE__);
    success &= Halide::Internal::check_introspection(&a->a_b, "HalideIntrospectionCanary::A::B", my_name + ".a_b", __FILE__, __LINE__);
    success &= Halide::Internal::check_introspection(&a->a_b.parent, "HalideIntrospectionCanary::A \\*", my_name + ".a_b.parent", __FILE__, __LINE__);
    success &= Halide::Internal::check_introspection(&a->a_b.a_float, "float", my_name + ".a_b.a_float", __FILE__, __LINE__);
    success &= Halide::Internal::check_introspection(a->a_b.parent, "HalideIntrospectionCanary::A", my_name, __FILE__, __LINE__);
    return success;
}
23092
// Run test_a on two distinct stack variables, so the introspection
// subsystem must distinguish them by address.
static bool test(bool (*f)(const void *, const std::string &)) {
    A a1, a2;

    // Call via pointer to prevent inlining.
    return f(&a1, "a1") && f(&a2, "a2");
}
23099
// Run the tests, and calibrate for the PC offset at static initialization time.
namespace {
// Constructing the single static instance below (test_object) triggers the
// introspection self-test for this compilation unit.
struct TestCompilationUnit {
    TestCompilationUnit() {
        Halide::Internal::Introspection::test_compilation_unit(&test, &test_a, &offset_marker);
    }
};
} // namespace

static TestCompilationUnit test_object;
23110
23111} // namespace HalideIntrospectionCanary
23112
23113#endif
23114
23115#endif
23116#ifndef HALIDE_OBJECT_INSTANCE_REGISTRY_H
23117#define HALIDE_OBJECT_INSTANCE_REGISTRY_H
23118
23119/** \file
23120 *
23121 * Provides a single global registry of Generators, GeneratorParams,
23122 * and Params indexed by this pointer. This is used for finding the
 * parameters inside of a Generator. NOTE: the registry is guarded by an
 * internal mutex, and is thread-safe under C++11 or later.
23125 */
23126
23127#include <cstddef>
23128#include <cstdint>
23129
23130#include <map>
23131#include <mutex>
23132#include <vector>
23133
23134namespace Halide {
23135namespace Internal {
23136
class ObjectInstanceRegistry {
public:
    /** What kind of object an instance represents; used to filter queries
     * in instances_in_range(). */
    enum Kind {
        Invalid,
        Generator,
        GeneratorParam,
        GeneratorInput,
        GeneratorOutput,
        FilterParam
    };

    /** Add an instance to the registry. The size may be 0 for Param Kinds,
     * but not for Generator. subject_ptr is the value actually associated
     * with this instance; it is usually (but not necessarily) the same
     * as this_ptr. Assert if this_ptr is already registered.
     *
     * If 'this' is directly heap allocated (not a member of a
     * heap-allocated object) and you want the introspection subsystem
     * to know about it and its members, set the introspection_helper
     * argument to a pointer to a global variable with the same true
     * type as 'this'. For example:
     *
     * MyObject *obj = new MyObject;
     * static MyObject *introspection_helper = nullptr;
     * register_instance(obj, sizeof(MyObject), kind, obj, &introspection_helper);
     *
     * I.e. introspection_helper should be a pointer to a pointer to
     * an object instance. The inner pointer can be null. The
     * introspection subsystem will then assume this new object is of
     * the matching type, which will help its members deduce their
     * names on construction.
     */
    static void register_instance(void *this_ptr, size_t size, Kind kind, void *subject_ptr,
                                  const void *introspection_helper);

    /** Remove an instance from the registry. Assert if not found.
     */
    static void unregister_instance(void *this_ptr);

    /** Returns the list of subject pointers for objects that have
     * been directly registered within the given range. If there is
     * another containing object inside the range, instances within
     * that object are skipped.
     */
    static std::vector<void *> instances_in_range(void *start, size_t size, Kind kind);

private:
    /** Accessor for the process-wide registry instance. */
    static ObjectInstanceRegistry &get_registry();

    /** Per-instance bookkeeping stored against the registered this_ptr. */
    struct InstanceInfo {
        void *subject_ptr = nullptr;  // May be different from the this_ptr in the key
        size_t size = 0;              // May be 0 for params
        Kind kind = Invalid;
        bool registered_for_introspection = false;

        InstanceInfo() = default;
        InstanceInfo(size_t size, Kind kind, void *subject_ptr, bool registered_for_introspection)
            : subject_ptr(subject_ptr), size(size), kind(kind), registered_for_introspection(registered_for_introspection) {
        }
    };

    // Guards 'instances' (see the file comment regarding thread safety).
    std::mutex mutex;
    // Keyed by the registered this_ptr, stored as an integer address.
    std::map<uintptr_t, InstanceInfo> instances;

    ObjectInstanceRegistry() = default;

public:
    // Singleton: non-copyable and non-movable.
    ObjectInstanceRegistry(const ObjectInstanceRegistry &) = delete;
    ObjectInstanceRegistry &operator=(const ObjectInstanceRegistry &) = delete;
    ObjectInstanceRegistry(ObjectInstanceRegistry &&) = delete;
    ObjectInstanceRegistry &operator=(ObjectInstanceRegistry &&) = delete;
};
23209
23210} // namespace Internal
23211} // namespace Halide
23212
23213#endif // HALIDE_OBJECT_INSTANCE_REGISTRY_H
23214
23215namespace Halide {
23216
23217template<typename T>
23218class Buffer;
23219
23220namespace Internal {
23221
23222void generator_test();
23223
23224/**
23225 * ValueTracker is an internal utility class that attempts to track and flag certain
23226 * obvious Stub-related errors at Halide compile time: it tracks the constraints set
23227 * on any Parameter-based argument (i.e., Input<Buffer> and Output<Buffer>) to
23228 * ensure that incompatible values aren't set.
23229 *
23230 * e.g.: if a Generator A requires stride[0] == 1,
23231 * and Generator B uses Generator A via stub, but requires stride[0] == 4,
23232 * we should be able to detect this at Halide compilation time, and fail immediately,
23233 * rather than producing code that fails at runtime and/or runs slowly due to
23234 * vectorization being unavailable.
23235 *
23236 * We do this by tracking the active values at entrance and exit to all user-provided
23237 * Generator methods (build()/generate()/schedule()); if we ever find more than two unique
23238 * values active, we know we have a potential conflict. ("two" here because the first
23239 * value is the default value for a given constraint.)
23240 *
23241 * Note that this won't catch all cases:
23242 * -- JIT compilation has no way to check for conflicts at the top-level
23243 * -- constraints that match the default value (e.g. if dim(0).set_stride(1) is the
23244 * first value seen by the tracker) will be ignored, so an explicit requirement set
23245 * this way can be missed
23246 *
23247 * Nevertheless, this is likely to be much better than nothing when composing multiple
23248 * layers of Stubs in a single fused result.
23249 */
class ValueTracker {
private:
    /** For each tracked name, the distinct sets of constraint values that
     * have been observed so far. */
    std::map<std::string, std::vector<std::vector<Expr>>> values_history;
    /** Maximum number of distinct value sets allowed per name; defaults to 2
     * (the default value for a constraint plus one explicitly-set value --
     * see the class comment above). */
    const size_t max_unique_values;

public:
    explicit ValueTracker(size_t max_unique_values = 2)
        : max_unique_values(max_unique_values) {
    }
    /** Record 'values' for 'name'; presumably flags an error once more than
     * max_unique_values distinct sets accumulate -- confirm against the
     * out-of-line definition. */
    void track_values(const std::string &name, const std::vector<Expr> &values);
};
23261
23262std::vector<Expr> parameter_constraints(const Parameter &p);
23263
23264template<typename T>
23265HALIDE_NO_USER_CODE_INLINE std::string enum_to_string(const std::map<std::string, T> &enum_map, const T &t) {
23266 for (const auto &key_value : enum_map) {
23267 if (t == key_value.second) {
23268 return key_value.first;
23269 }
23270 }
23271 user_error << "Enumeration value not found.\n";
23272 return "";
23273}
23274
/** Forward lookup in a name->value enum map: map the string s to its value,
 * failing with user_assert if the name is unknown. */
template<typename T>
T enum_from_string(const std::map<std::string, T> &enum_map, const std::string &s) {
    const auto entry = enum_map.find(s);
    user_assert(entry != enum_map.end()) << "Enumeration value not found: " << s << "\n";
    return entry->second;
}
23281
23282extern const std::map<std::string, Halide::Type> &get_halide_type_enum_map();
23283inline std::string halide_type_to_enum_string(const Type &t) {
23284 return enum_to_string(get_halide_type_enum_map(), t);
23285}
23286
23287// Convert a Halide Type into a string representation of its C source.
23288// e.g., Int(32) -> "Halide::Int(32)"
23289std::string halide_type_to_c_source(const Type &t);
23290
23291// Convert a Halide Type into a string representation of its C Source.
23292// e.g., Int(32) -> "int32_t"
23293std::string halide_type_to_c_type(const Type &t);
23294
23295/** generate_filter_main() is a convenient wrapper for GeneratorRegistry::create() +
23296 * compile_to_files(); it can be trivially wrapped by a "real" main() to produce a
23297 * command-line utility for ahead-of-time filter compilation. */
23298int generate_filter_main(int argc, char **argv, std::ostream &cerr);
23299
23300// select_type<> is to std::conditional as switch is to if:
23301// it allows a multiway compile-time type definition via the form
23302//
23303// select_type<cond<condition1, type1>,
23304// cond<condition2, type2>,
23305// ....
23306// cond<conditionN, typeN>>::type
23307//
23308// Note that the conditions are evaluated in order; the first evaluating to true
23309// is chosen.
23310//
23311// Note that if no conditions evaluate to true, the resulting type is illegal
23312// and will produce a compilation error. (You can provide a default by simply
23313// using cond<true, SomeType> as the final entry.)
23314template<bool B, typename T>
23315struct cond {
23316 static constexpr bool value = B;
23317 using type = T;
23318};
23319
23320template<typename First, typename... Rest>
23321struct select_type : std::conditional<First::value, typename First::type, typename select_type<Rest...>::type> {};
23322
23323template<typename First>
23324struct select_type<First> { using type = typename std::conditional<First::value, typename First::type, void>::type; };
23325
23326class GeneratorBase;
23327class GeneratorParamInfo;
23328
// Type-erased base class for all GeneratorParam<> instantiations. Holds the
// param's name and declares the typed virtual setter interface; concrete
// storage and typing live in GeneratorParamImpl<T> below.
class GeneratorParamBase {
public:
    explicit GeneratorParamBase(const std::string &name);
    virtual ~GeneratorParamBase();

    // The name this param was declared with (used for string-based setting
    // and for code emission in the Stub generator).
    inline const std::string &name() const {
        return name_;
    }

    // overload the set() function to call the right virtual method based on type.
    // This allows us to attempt to set a GeneratorParam via a
    // plain C++ type, even if we don't know the specific templated
    // subclass. Attempting to set the wrong type will assert.
    // Notice that there is no typed setter for Enums, for obvious reasons;
    // setting enums in an unknown type must fallback to using set_from_string.
    //
    // It's always a bit iffy to use macros for this, but IMHO it clarifies the situation here.
#define HALIDE_GENERATOR_PARAM_TYPED_SETTER(TYPE) \
    virtual void set(const TYPE &new_value) = 0;

    HALIDE_GENERATOR_PARAM_TYPED_SETTER(bool)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(int8_t)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(int16_t)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(int32_t)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(int64_t)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(uint8_t)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(uint16_t)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(uint32_t)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(uint64_t)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(float)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(double)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(Target)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(MachineParams)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(Type)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(LoopLevel)

#undef HALIDE_GENERATOR_PARAM_TYPED_SETTER

    // Add overloads for string and char*
    // (non-virtual: both funnel into the set_from_string() virtual below).
    void set(const std::string &new_value) {
        set_from_string(new_value);
    }
    void set(const char *new_value) {
        set_from_string(std::string(new_value));
    }

protected:
    friend class GeneratorBase;
    friend class GeneratorParamInfo;
    friend class StubEmitter;

    // Lifecycle guards: assert that reading/writing the value is legal at the
    // current point in the Generator build flow. (Declarations only here;
    // definitions live elsewhere in this file.)
    void check_value_readable() const;
    void check_value_writable() const;

    // All GeneratorParams are settable from string.
    virtual void set_from_string(const std::string &value_string) = 0;

    // Source-emission hooks used by the Stub generator:
    // an expression converting variable `v` to a std::string, and the C++
    // spelling of this param's type (see the concrete subclasses below).
    virtual std::string call_to_string(const std::string &v) const = 0;
    virtual std::string get_c_type() const = 0;

    // Any type declarations (e.g. a generated enum) this param needs emitted;
    // most params need none.
    virtual std::string get_type_decls() const {
        return "";
    }

    // The param's default value, rendered as a C++ source expression.
    virtual std::string get_default_value() const = 0;

    // Overridden to return true only for internally-synthesized params.
    virtual bool is_synthetic_param() const {
        return false;
    }

    // Overridden to return true only by GeneratorParam_LoopLevel.
    virtual bool is_looplevel_param() const {
        return false;
    }

    // Report a typed set() call with an incompatible type; `type` is the
    // stringized name of the type the caller attempted to use.
    void fail_wrong_type(const char *type);

private:
    const std::string name_;

    // Generator which owns this GeneratorParam. Note that this will be null
    // initially; the GeneratorBase itself will set this field when it initially
    // builds its info about params. However, since it (generally) isn't
    // appropriate for GeneratorParam<> to be declared outside of a Generator,
    // all reasonable non-testing code should expect this to be non-null.
    GeneratorBase *generator{nullptr};

public:
    // Non-copyable and non-movable: the owning Generator keeps a pointer to
    // this object, so its address must remain stable.
    GeneratorParamBase(const GeneratorParamBase &) = delete;
    GeneratorParamBase &operator=(const GeneratorParamBase &) = delete;
    GeneratorParamBase(GeneratorParamBase &&) = delete;
    GeneratorParamBase &operator=(GeneratorParamBase &&) = delete;
};
23421
23422// This is strictly some syntactic sugar to suppress certain compiler warnings.
23423template<typename FROM, typename TO>
23424struct Convert {
23425 template<typename TO2 = TO, typename std::enable_if<!std::is_same<TO2, bool>::value>::type * = nullptr>
23426 inline static TO2 value(const FROM &from) {
23427 return static_cast<TO2>(from);
23428 }
23429
23430 template<typename TO2 = TO, typename std::enable_if<std::is_same<TO2, bool>::value>::type * = nullptr>
23431 inline static TO2 value(const FROM &from) {
23432 return from != 0;
23433 }
23434};
23435
// Typed storage layer for GeneratorParam<>: holds a T value and implements
// the type-erased setter interface declared in GeneratorParamBase.
template<typename T>
class GeneratorParamImpl : public GeneratorParamBase {
public:
    using type = T;

    GeneratorParamImpl(const std::string &name, const T &value)
        : GeneratorParamBase(name), value_(value) {
    }

    // Read the current value; asserts if reading isn't legal yet
    // (see check_value_readable()).
    T value() const {
        this->check_value_readable();
        return value_;
    }

    // Implicit conversion to the underlying type...
    operator T() const {
        return this->value();
    }

    // ...and to an Expr, so params can appear directly in Halide expressions.
    operator Expr() const {
        return make_const(type_of<T>(), this->value());
    }

    // Each typed setter forwards to typed_setter_impl<>, which statically
    // dispatches on the legality of the FROM->T conversion (see the private
    // overloads below); #TYPE supplies a readable name for error messages.
#define HALIDE_GENERATOR_PARAM_TYPED_SETTER(TYPE) \
    void set(const TYPE &new_value) override { \
        typed_setter_impl<TYPE>(new_value, #TYPE); \
    }

    HALIDE_GENERATOR_PARAM_TYPED_SETTER(bool)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(int8_t)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(int16_t)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(int32_t)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(int64_t)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(uint8_t)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(uint16_t)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(uint32_t)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(uint64_t)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(float)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(double)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(Target)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(MachineParams)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(Type)
    HALIDE_GENERATOR_PARAM_TYPED_SETTER(LoopLevel)

#undef HALIDE_GENERATOR_PARAM_TYPED_SETTER

    // Overload for std::string.
    // Non-virtual; assigns directly to value_, so it is only well-formed
    // (and, being a non-template member of a class template, only
    // instantiated) when T is assignable from std::string.
    void set(const std::string &new_value) {
        check_value_writable();
        value_ = new_value;
    }

protected:
    // Subclasses that validate values (e.g. GeneratorParam_Arithmetic's
    // min/max check) override this.
    virtual void set_impl(const T &new_value) {
        check_value_writable();
        value_ = new_value;
    }

    // Needs to be protected to allow GeneratorParam<LoopLevel>::set() override
    T value_;

private:
    // If FROM->T is not legal, fail
    template<typename FROM, typename std::enable_if<
                                !std::is_convertible<FROM, T>::value>::type * = nullptr>
    HALIDE_ALWAYS_INLINE void typed_setter_impl(const FROM &, const char *msg) {
        fail_wrong_type(msg);
    }

    // If FROM and T are identical, just assign
    template<typename FROM, typename std::enable_if<
                                std::is_same<FROM, T>::value>::type * = nullptr>
    HALIDE_ALWAYS_INLINE void typed_setter_impl(const FROM &value, const char *msg) {
        check_value_writable();
        value_ = value;
    }

    // If both FROM->T and T->FROM are legal, ensure it's lossless
    // (round-trip the value and fail if it doesn't survive intact).
    template<typename FROM, typename std::enable_if<
                                !std::is_same<FROM, T>::value &&
                                std::is_convertible<FROM, T>::value &&
                                std::is_convertible<T, FROM>::value>::type * = nullptr>
    HALIDE_ALWAYS_INLINE void typed_setter_impl(const FROM &value, const char *msg) {
        check_value_writable();
        const T t = Convert<FROM, T>::value(value);
        const FROM value2 = Convert<T, FROM>::value(t);
        if (value2 != value) {
            fail_wrong_type(msg);
        }
        value_ = t;
    }

    // If FROM->T is legal but T->FROM is not, just assign
    // (no round-trip possible, so no losslessness check).
    template<typename FROM, typename std::enable_if<
                                !std::is_same<FROM, T>::value &&
                                std::is_convertible<FROM, T>::value &&
                                !std::is_convertible<T, FROM>::value>::type * = nullptr>
    HALIDE_ALWAYS_INLINE void typed_setter_impl(const FROM &value, const char *msg) {
        check_value_writable();
        value_ = value;
    }
};
23537
23538// Stubs for type-specific implementations of GeneratorParam, to avoid
23539// many complex enable_if<> statements that were formerly spread through the
23540// implementation. Note that not all of these need to be templated classes,
23541// (e.g. for GeneratorParam_Target, T == Target always), but are declared
23542// that way for symmetry of declaration.
23543template<typename T>
23544class GeneratorParam_Target : public GeneratorParamImpl<T> {
23545public:
23546 GeneratorParam_Target(const std::string &name, const T &value)
23547 : GeneratorParamImpl<T>(name, value) {
23548 }
23549
23550 void set_from_string(const std::string &new_value_string) override {
23551 this->set(Target(new_value_string));
23552 }
23553
23554 std::string get_default_value() const override {
23555 return this->value().to_string();
23556 }
23557
23558 std::string call_to_string(const std::string &v) const override {
23559 std::ostringstream oss;
23560 oss << v << ".to_string()";
23561 return oss.str();
23562 }
23563
23564 std::string get_c_type() const override {
23565 return "Target";
23566 }
23567};
23568
23569template<typename T>
23570class GeneratorParam_MachineParams : public GeneratorParamImpl<T> {
23571public:
23572 GeneratorParam_MachineParams(const std::string &name, const T &value)
23573 : GeneratorParamImpl<T>(name, value) {
23574 }
23575
23576 void set_from_string(const std::string &new_value_string) override {
23577 this->set(MachineParams(new_value_string));
23578 }
23579
23580 std::string get_default_value() const override {
23581 return this->value().to_string();
23582 }
23583
23584 std::string call_to_string(const std::string &v) const override {
23585 std::ostringstream oss;
23586 oss << v << ".to_string()";
23587 return oss.str();
23588 }
23589
23590 std::string get_c_type() const override {
23591 return "MachineParams";
23592 }
23593};
23594
23595class GeneratorParam_LoopLevel : public GeneratorParamImpl<LoopLevel> {
23596public:
23597 GeneratorParam_LoopLevel(const std::string &name, const LoopLevel &value)
23598 : GeneratorParamImpl<LoopLevel>(name, value) {
23599 }
23600
23601 using GeneratorParamImpl<LoopLevel>::set;
23602
23603 void set(const LoopLevel &value) override {
23604 // Don't call check_value_writable(): It's OK to set a LoopLevel after generate().
23605 // check_value_writable();
23606
23607 // This looks odd, but is deliberate:
23608
23609 // First, mutate the existing contents to match the value passed in,
23610 // so that any existing usage of the LoopLevel now uses the newer value.
23611 // (Strictly speaking, this is really only necessary if this method
23612 // is called after generate(): before generate(), there is no usage
23613 // to be concerned with.)
23614 value_.set(value);
23615
23616 // Then, reset the value itself so that it points to the same LoopLevelContents
23617 // as the value passed in. (Strictly speaking, this is really only
23618 // useful if this method is called before generate(): afterwards, it's
23619 // too late to alter the code to refer to a different LoopLevelContents.)
23620 value_ = value;
23621 }
23622
23623 void set_from_string(const std::string &new_value_string) override {
23624 if (new_value_string == "root") {
23625 this->set(LoopLevel::root());
23626 } else if (new_value_string == "inlined") {
23627 this->set(LoopLevel::inlined());
23628 } else {
23629 user_error << "Unable to parse " << this->name() << ": " << new_value_string;
23630 }
23631 }
23632
23633 std::string get_default_value() const override {
23634 // This is dodgy but safe in this case: we want to
23635 // see what the value of our LoopLevel is *right now*,
23636 // so we make a copy and lock the copy so we can inspect it.
23637 // (Note that ordinarily this is a bad idea, since LoopLevels
23638 // can be mutated later on; however, this method is only
23639 // called by the Generator infrastructure, on LoopLevels that
23640 // will never be mutated, so this is really just an elaborate way
23641 // to avoid runtime assertions.)
23642 LoopLevel copy;
23643 copy.set(this->value());
23644 copy.lock();
23645 if (copy.is_inlined()) {
23646 return "LoopLevel::inlined()";
23647 } else if (copy.is_root()) {
23648 return "LoopLevel::root()";
23649 } else {
23650 internal_error;
23651 return "";
23652 }
23653 }
23654
23655 std::string call_to_string(const std::string &v) const override {
23656 internal_error;
23657 return std::string();
23658 }
23659
23660 std::string get_c_type() const override {
23661 return "LoopLevel";
23662 }
23663
23664 bool is_looplevel_param() const override {
23665 return true;
23666 }
23667};
23668
23669template<typename T>
23670class GeneratorParam_Arithmetic : public GeneratorParamImpl<T> {
23671public:
23672 GeneratorParam_Arithmetic(const std::string &name,
23673 const T &value,
23674 const T &min = std::numeric_limits<T>::lowest(),
23675 const T &max = std::numeric_limits<T>::max())
23676 : GeneratorParamImpl<T>(name, value), min(min), max(max) {
23677 // call set() to ensure value is clamped to min/max
23678 this->set(value);
23679 }
23680
23681 void set_impl(const T &new_value) override {
23682 user_assert(new_value >= min && new_value <= max) << "Value out of range: " << new_value;
23683 GeneratorParamImpl<T>::set_impl(new_value);
23684 }
23685
23686 void set_from_string(const std::string &new_value_string) override {
23687 std::istringstream iss(new_value_string);
23688 T t;
23689 // All one-byte ints int8 and uint8 should be parsed as integers, not chars --
23690 // including 'char' itself. (Note that sizeof(bool) is often-but-not-always-1,
23691 // so be sure to exclude that case.)
23692 if (sizeof(T) == sizeof(char) && !std::is_same<T, bool>::value) {
23693 int i;
23694 iss >> i;
23695 t = (T)i;
23696 } else {
23697 iss >> t;
23698 }
23699 user_assert(!iss.fail() && iss.get() == EOF) << "Unable to parse: " << new_value_string;
23700 this->set(t);
23701 }
23702
23703 std::string get_default_value() const override {
23704 std::ostringstream oss;
23705 oss << this->value();
23706 if (std::is_same<T, float>::value) {
23707 // If the constant has no decimal point ("1")
23708 // we must append one before appending "f"
23709 if (oss.str().find('.') == std::string::npos) {
23710 oss << ".";
23711 }
23712 oss << "f";
23713 }
23714 return oss.str();
23715 }
23716
23717 std::string call_to_string(const std::string &v) const override {
23718 std::ostringstream oss;
23719 oss << "std::to_string(" << v << ")";
23720 return oss.str();
23721 }
23722
23723 std::string get_c_type() const override {
23724 std::ostringstream oss;
23725 if (std::is_same<T, float>::value) {
23726 return "float";
23727 } else if (std::is_same<T, double>::value) {
23728 return "double";
23729 } else if (std::is_integral<T>::value) {
23730 if (std::is_unsigned<T>::value) {
23731 oss << "u";
23732 }
23733 oss << "int" << (sizeof(T) * 8) << "_t";
23734 return oss.str();
23735 } else {
23736 user_error << "Unknown arithmetic type\n";
23737 return "";
23738 }
23739 }
23740
23741private:
23742 const T min, max;
23743};
23744
23745template<typename T>
23746class GeneratorParam_Bool : public GeneratorParam_Arithmetic<T> {
23747public:
23748 GeneratorParam_Bool(const std::string &name, const T &value)
23749 : GeneratorParam_Arithmetic<T>(name, value) {
23750 }
23751
23752 void set_from_string(const std::string &new_value_string) override {
23753 bool v = false;
23754 if (new_value_string == "true" || new_value_string == "True") {
23755 v = true;
23756 } else if (new_value_string == "false" || new_value_string == "False") {
23757 v = false;
23758 } else {
23759 user_assert(false) << "Unable to parse bool: " << new_value_string;
23760 }
23761 this->set(v);
23762 }
23763
23764 std::string get_default_value() const override {
23765 return this->value() ? "true" : "false";
23766 }
23767
23768 std::string call_to_string(const std::string &v) const override {
23769 std::ostringstream oss;
23770 oss << "std::string((" << v << ") ? \"true\" : \"false\")";
23771 return oss.str();
23772 }
23773
23774 std::string get_c_type() const override {
23775 return "bool";
23776 }
23777};
23778
23779template<typename T>
23780class GeneratorParam_Enum : public GeneratorParamImpl<T> {
23781public:
23782 GeneratorParam_Enum(const std::string &name, const T &value, const std::map<std::string, T> &enum_map)
23783 : GeneratorParamImpl<T>(name, value), enum_map(enum_map) {
23784 }
23785
23786 // define a "set" that takes our specific enum (but don't hide the inherited virtual functions)
23787 using GeneratorParamImpl<T>::set;
23788
23789 template<typename T2 = T, typename std::enable_if<!std::is_same<T2, Type>::value>::type * = nullptr>
23790 void set(const T &e) {
23791 this->set_impl(e);
23792 }
23793
23794 void set_from_string(const std::string &new_value_string) override {
23795 auto it = enum_map.find(new_value_string);
23796 user_assert(it != enum_map.end()) << "Enumeration value not found: " << new_value_string;
23797 this->set_impl(it->second);
23798 }
23799
23800 std::string call_to_string(const std::string &v) const override {
23801 return "Enum_" + this->name() + "_map().at(" + v + ")";
23802 }
23803
23804 std::string get_c_type() const override {
23805 return "Enum_" + this->name();
23806 }
23807
23808 std::string get_default_value() const override {
23809 return "Enum_" + this->name() + "::" + enum_to_string(enum_map, this->value());
23810 }
23811
23812 std::string get_type_decls() const override {
23813 std::ostringstream oss;
23814 oss << "enum class Enum_" << this->name() << " {\n";
23815 for (auto key_value : enum_map) {
23816 oss << " " << key_value.first << ",\n";
23817 }
23818 oss << "};\n";
23819 oss << "\n";
23820
23821 // TODO: since we generate the enums, we could probably just use a vector (or array!) rather than a map,
23822 // since we can ensure that the enum values are a nice tight range.
23823 oss << "inline HALIDE_NO_USER_CODE_INLINE const std::map<Enum_" << this->name() << ", std::string>& Enum_" << this->name() << "_map() {\n";
23824 oss << " static const std::map<Enum_" << this->name() << ", std::string> m = {\n";
23825 for (auto key_value : enum_map) {
23826 oss << " { Enum_" << this->name() << "::" << key_value.first << ", \"" << key_value.first << "\"},\n";
23827 }
23828 oss << " };\n";
23829 oss << " return m;\n";
23830 oss << "};\n";
23831 return oss.str();
23832 }
23833
23834private:
23835 const std::map<std::string, T> enum_map;
23836};
23837
23838template<typename T>
23839class GeneratorParam_Type : public GeneratorParam_Enum<T> {
23840public:
23841 GeneratorParam_Type(const std::string &name, const T &value)
23842 : GeneratorParam_Enum<T>(name, value, get_halide_type_enum_map()) {
23843 }
23844
23845 std::string call_to_string(const std::string &v) const override {
23846 return "Halide::Internal::halide_type_to_enum_string(" + v + ")";
23847 }
23848
23849 std::string get_c_type() const override {
23850 return "Type";
23851 }
23852
23853 std::string get_default_value() const override {
23854 return halide_type_to_c_source(this->value());
23855 }
23856
23857 std::string get_type_decls() const override {
23858 return "";
23859 }
23860};
23861
23862template<typename T>
23863class GeneratorParam_String : public Internal::GeneratorParamImpl<T> {
23864public:
23865 GeneratorParam_String(const std::string &name, const std::string &value)
23866 : GeneratorParamImpl<T>(name, value) {
23867 }
23868 void set_from_string(const std::string &new_value_string) override {
23869 this->set(new_value_string);
23870 }
23871
23872 std::string get_default_value() const override {
23873 return "\"" + this->value() + "\"";
23874 }
23875
23876 std::string call_to_string(const std::string &v) const override {
23877 return v;
23878 }
23879
23880 std::string get_c_type() const override {
23881 return "std::string";
23882 }
23883};
23884
// Selects the concrete GeneratorParam_* implementation for a given T, in
// priority order (e.g. bool matches its dedicated case before the generic
// arithmetic one). An unsupported T yields void and fails to compile at the
// point of use; see select_type<> above.
template<typename T>
using GeneratorParamImplBase =
    typename select_type<
        cond<std::is_same<T, Target>::value, GeneratorParam_Target<T>>,
        cond<std::is_same<T, MachineParams>::value, GeneratorParam_MachineParams<T>>,
        cond<std::is_same<T, LoopLevel>::value, GeneratorParam_LoopLevel>,
        cond<std::is_same<T, std::string>::value, GeneratorParam_String<T>>,
        cond<std::is_same<T, Type>::value, GeneratorParam_Type<T>>,
        cond<std::is_same<T, bool>::value, GeneratorParam_Bool<T>>,
        cond<std::is_arithmetic<T>::value, GeneratorParam_Arithmetic<T>>,
        cond<std::is_enum<T>::value, GeneratorParam_Enum<T>>>::type;
23896
23897} // namespace Internal
23898
23899/** GeneratorParam is a templated class that can be used to modify the behavior
23900 * of the Generator at code-generation time. GeneratorParams are commonly
23901 * specified in build files (e.g. Makefile) to customize the behavior of
23902 * a given Generator, thus they have a very constrained set of types to allow
23903 * for efficient specification via command-line flags. A GeneratorParam can be:
23904 * - any float or int type.
23905 * - bool
23906 * - enum
23907 * - Halide::Target
23908 * - Halide::Type
23909 * - std::string
23910 * Please don't use std::string unless there's no way to do what you want with some
23911 * other type; in particular, don't use this if you can use enum instead.
23912 * All GeneratorParams have a default value. Arithmetic types can also
23913 * optionally specify min and max. Enum types must specify a string-to-value
23914 * map.
23915 *
23916 * Halide::Type is treated as though it were an enum, with the mappings:
23917 *
23918 * "int8" Halide::Int(8)
23919 * "int16" Halide::Int(16)
23920 * "int32" Halide::Int(32)
23921 * "uint8" Halide::UInt(8)
23922 * "uint16" Halide::UInt(16)
23923 * "uint32" Halide::UInt(32)
23924 * "float32" Halide::Float(32)
23925 * "float64" Halide::Float(64)
23926 *
23927 * No vector Types are currently supported by this mapping.
23928 *
23929 */
template<typename T>
class GeneratorParam : public Internal::GeneratorParamImplBase<T> {
public:
    // General-purpose constructor; disabled for std::string, which uses the
    // dedicated overload below.
    template<typename T2 = T, typename std::enable_if<!std::is_same<T2, std::string>::value>::type * = nullptr>
    GeneratorParam(const std::string &name, const T &value)
        : Internal::GeneratorParamImplBase<T>(name, value) {
    }

    // Arithmetic params may additionally specify an inclusive [min, max] range.
    GeneratorParam(const std::string &name, const T &value, const T &min, const T &max)
        : Internal::GeneratorParamImplBase<T>(name, value, min, max) {
    }

    // Enum params must supply a string->value map.
    GeneratorParam(const std::string &name, const T &value, const std::map<std::string, T> &enum_map)
        : Internal::GeneratorParamImplBase<T>(name, value, enum_map) {
    }

    // String-valued params.
    GeneratorParam(const std::string &name, const std::string &value)
        : Internal::GeneratorParamImplBase<T>(name, value) {
    }
};
23950
23951/** Addition between GeneratorParam<T> and any type that supports operator+ with T.
23952 * Returns type of underlying operator+. */
23953// @{
23954template<typename Other, typename T>
23955auto operator+(const Other &a, const GeneratorParam<T> &b) -> decltype(a + (T)b) {
23956 return a + (T)b;
23957}
23958template<typename Other, typename T>
23959auto operator+(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a + b) {
23960 return (T)a + b;
23961}
23962// @}
23963
23964/** Subtraction between GeneratorParam<T> and any type that supports operator- with T.
23965 * Returns type of underlying operator-. */
23966// @{
23967template<typename Other, typename T>
23968auto operator-(const Other &a, const GeneratorParam<T> &b) -> decltype(a - (T)b) {
23969 return a - (T)b;
23970}
23971template<typename Other, typename T>
23972auto operator-(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a - b) {
23973 return (T)a - b;
23974}
23975// @}
23976
23977/** Multiplication between GeneratorParam<T> and any type that supports operator* with T.
23978 * Returns type of underlying operator*. */
23979// @{
23980template<typename Other, typename T>
23981auto operator*(const Other &a, const GeneratorParam<T> &b) -> decltype(a * (T)b) {
23982 return a * (T)b;
23983}
23984template<typename Other, typename T>
23985auto operator*(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a * b) {
23986 return (T)a * b;
23987}
23988// @}
23989
23990/** Division between GeneratorParam<T> and any type that supports operator/ with T.
23991 * Returns type of underlying operator/. */
23992// @{
23993template<typename Other, typename T>
23994auto operator/(const Other &a, const GeneratorParam<T> &b) -> decltype(a / (T)b) {
23995 return a / (T)b;
23996}
23997template<typename Other, typename T>
23998auto operator/(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a / b) {
23999 return (T)a / b;
24000}
24001// @}
24002
24003/** Modulo between GeneratorParam<T> and any type that supports operator% with T.
24004 * Returns type of underlying operator%. */
24005// @{
24006template<typename Other, typename T>
24007auto operator%(const Other &a, const GeneratorParam<T> &b) -> decltype(a % (T)b) {
24008 return a % (T)b;
24009}
24010template<typename Other, typename T>
24011auto operator%(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a % b) {
24012 return (T)a % b;
24013}
24014// @}
24015
24016/** Greater than comparison between GeneratorParam<T> and any type that supports operator> with T.
24017 * Returns type of underlying operator>. */
24018// @{
24019template<typename Other, typename T>
24020auto operator>(const Other &a, const GeneratorParam<T> &b) -> decltype(a > (T)b) {
24021 return a > (T)b;
24022}
24023template<typename Other, typename T>
24024auto operator>(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a > b) {
24025 return (T)a > b;
24026}
24027// @}
24028
24029/** Less than comparison between GeneratorParam<T> and any type that supports operator< with T.
24030 * Returns type of underlying operator<. */
24031// @{
24032template<typename Other, typename T>
24033auto operator<(const Other &a, const GeneratorParam<T> &b) -> decltype(a < (T)b) {
24034 return a < (T)b;
24035}
24036template<typename Other, typename T>
24037auto operator<(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a < b) {
24038 return (T)a < b;
24039}
24040// @}
24041
24042/** Greater than or equal comparison between GeneratorParam<T> and any type that supports operator>= with T.
24043 * Returns type of underlying operator>=. */
24044// @{
24045template<typename Other, typename T>
24046auto operator>=(const Other &a, const GeneratorParam<T> &b) -> decltype(a >= (T)b) {
24047 return a >= (T)b;
24048}
24049template<typename Other, typename T>
24050auto operator>=(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a >= b) {
24051 return (T)a >= b;
24052}
24053// @}
24054
24055/** Less than or equal comparison between GeneratorParam<T> and any type that supports operator<= with T.
24056 * Returns type of underlying operator<=. */
24057// @{
24058template<typename Other, typename T>
24059auto operator<=(const Other &a, const GeneratorParam<T> &b) -> decltype(a <= (T)b) {
24060 return a <= (T)b;
24061}
24062template<typename Other, typename T>
24063auto operator<=(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a <= b) {
24064 return (T)a <= b;
24065}
24066// @}
24067
24068/** Equality comparison between GeneratorParam<T> and any type that supports operator== with T.
24069 * Returns type of underlying operator==. */
24070// @{
24071template<typename Other, typename T>
24072auto operator==(const Other &a, const GeneratorParam<T> &b) -> decltype(a == (T)b) {
24073 return a == (T)b;
24074}
24075template<typename Other, typename T>
24076auto operator==(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a == b) {
24077 return (T)a == b;
24078}
24079// @}
24080
24081/** Inequality comparison between between GeneratorParam<T> and any type that supports operator!= with T.
24082 * Returns type of underlying operator!=. */
24083// @{
24084template<typename Other, typename T>
24085auto operator!=(const Other &a, const GeneratorParam<T> &b) -> decltype(a != (T)b) {
24086 return a != (T)b;
24087}
24088template<typename Other, typename T>
24089auto operator!=(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a != b) {
24090 return (T)a != b;
24091}
24092// @}
24093
24094/** Logical and between between GeneratorParam<T> and any type that supports operator&& with T.
24095 * Returns type of underlying operator&&. */
24096// @{
24097template<typename Other, typename T>
24098auto operator&&(const Other &a, const GeneratorParam<T> &b) -> decltype(a && (T)b) {
24099 return a && (T)b;
24100}
24101template<typename Other, typename T>
24102auto operator&&(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a && b) {
24103 return (T)a && b;
24104}
24105template<typename T>
24106auto operator&&(const GeneratorParam<T> &a, const GeneratorParam<T> &b) -> decltype((T)a && (T)b) {
24107 return (T)a && (T)b;
24108}
24109// @}
24110
24111/** Logical or between between GeneratorParam<T> and any type that supports operator|| with T.
24112 * Returns type of underlying operator||. */
24113// @{
24114template<typename Other, typename T>
24115auto operator||(const Other &a, const GeneratorParam<T> &b) -> decltype(a || (T)b) {
24116 return a || (T)b;
24117}
24118template<typename Other, typename T>
24119auto operator||(const GeneratorParam<T> &a, const Other &b) -> decltype((T)a || b) {
24120 return (T)a || b;
24121}
24122template<typename T>
24123auto operator||(const GeneratorParam<T> &a, const GeneratorParam<T> &b) -> decltype((T)a || (T)b) {
24124 return (T)a || (T)b;
24125}
24126// @}
24127
24128/* min and max are tricky as the language support for these is in the std
24129 * namespace. In order to make this work, forwarding functions are used that
24130 * are declared in a namespace that has std::min and std::max in scope.
24131 */
24132namespace Internal {
24133namespace GeneratorMinMax {
24134
24135using std::max;
24136using std::min;
24137
24138template<typename Other, typename T>
24139auto min_forward(const Other &a, const GeneratorParam<T> &b) -> decltype(min(a, (T)b)) {
24140 return min(a, (T)b);
24141}
24142template<typename Other, typename T>
24143auto min_forward(const GeneratorParam<T> &a, const Other &b) -> decltype(min((T)a, b)) {
24144 return min((T)a, b);
24145}
24146
24147template<typename Other, typename T>
24148auto max_forward(const Other &a, const GeneratorParam<T> &b) -> decltype(max(a, (T)b)) {
24149 return max(a, (T)b);
24150}
24151template<typename Other, typename T>
24152auto max_forward(const GeneratorParam<T> &a, const Other &b) -> decltype(max((T)a, b)) {
24153 return max((T)a, b);
24154}
24155
24156} // namespace GeneratorMinMax
24157} // namespace Internal
24158
/** Compute minimum between GeneratorParam<T> and any type that supports min with T.
 * Will automatically import std::min. Returns type of underlying min call. */
// @{
// (Forwarded through Internal::GeneratorMinMax so that std::min and any
// ADL-found min overloads are both viable candidates.)
template<typename Other, typename T>
auto min(const Other &a, const GeneratorParam<T> &b) -> decltype(Internal::GeneratorMinMax::min_forward(a, b)) {
    return Internal::GeneratorMinMax::min_forward(a, b);
}
template<typename Other, typename T>
auto min(const GeneratorParam<T> &a, const Other &b) -> decltype(Internal::GeneratorMinMax::min_forward(a, b)) {
    return Internal::GeneratorMinMax::min_forward(a, b);
}
// @}
24171
/** Compute the maximum value between GeneratorParam<T> and any type that supports max with T.
 * Will automatically import std::max. Returns type of underlying max call. */
// @{
// (Forwarded through Internal::GeneratorMinMax so that std::max and any
// ADL-found max overloads are both viable candidates.)
template<typename Other, typename T>
auto max(const Other &a, const GeneratorParam<T> &b) -> decltype(Internal::GeneratorMinMax::max_forward(a, b)) {
    return Internal::GeneratorMinMax::max_forward(a, b);
}
template<typename Other, typename T>
auto max(const GeneratorParam<T> &a, const Other &b) -> decltype(Internal::GeneratorMinMax::max_forward(a, b)) {
    return Internal::GeneratorMinMax::max_forward(a, b);
}
// @}
24184
24185/** Not operator for GeneratorParam */
24186template<typename T>
24187auto operator!(const GeneratorParam<T> &a) -> decltype(!(T)a) {
24188 return !(T)a;
24189}
24190
24191namespace Internal {
24192
24193template<typename T2>
24194class GeneratorInput_Buffer;
24195
// Classifies how a Generator input/output is conveyed: as a scalar,
// as a Func, or as a Buffer.
enum class IOKind { Scalar,
                    Function,
                    Buffer };
24199
24200/**
24201 * StubInputBuffer is the placeholder that a Stub uses when it requires
24202 * a Buffer for an input (rather than merely a Func or Expr). It is constructed
24203 * to allow only two possible sorts of input:
24204 * -- Assignment of an Input<Buffer<>>, with compatible type and dimensions,
24205 * essentially allowing us to pipe a parameter from an enclosing Generator to an internal Stub.
24206 * -- Assignment of a Buffer<>, with compatible type and dimensions,
24207 * causing the Input<Buffer<>> to become a precompiled buffer in the generated code.
24208 */
template<typename T = void>
class StubInputBuffer {
    friend class StubInput;
    template<typename T2>
    friend class GeneratorInput_Buffer;

    // The runtime parameter this input wraps (possibly backed by a
    // compile-time buffer; see parameter_from_buffer below).
    Parameter parameter_;

    HALIDE_NO_USER_CODE_INLINE explicit StubInputBuffer(const Parameter &p)
        : parameter_(p) {
        // Create an empty 1-element buffer with the right runtime typing and dimensions,
        // which we'll use only to pass to can_convert_from() to verify this
        // Parameter is compatible with our constraints.
        Buffer<> other(p.type(), nullptr, std::vector<int>(p.dimensions(), 1));
        internal_assert((Buffer<T>::can_convert_from(other)));
    }

    // Wrap a concrete Buffer in a buffer-backed Parameter, validating that
    // its type/dimensions are compatible with Buffer<T>, so the data can be
    // baked into the generated code as a fixed input.
    template<typename T2>
    HALIDE_NO_USER_CODE_INLINE static Parameter parameter_from_buffer(const Buffer<T2> &b) {
        internal_assert(b.defined());
        user_assert((Buffer<T>::can_convert_from(b)));
        Parameter p(b.type(), true, b.dimensions());
        p.set_buffer(b);
        return p;
    }

public:
    StubInputBuffer() = default;

    // *not* explicit -- this ctor should only be used when you want
    // to pass a literal Buffer<> for a Stub Input; this Buffer<> will be
    // compiled into the Generator's product, rather than becoming
    // a runtime Parameter.
    template<typename T2>
    StubInputBuffer(const Buffer<T2> &b)
        : parameter_(parameter_from_buffer(b)) {
    }
};
24247
// Non-template base for StubOutputBuffer<T>: holds the wrapped Func and the
// Generator that produced it, and provides the realize() entry points.
class StubOutputBufferBase {
protected:
    Func f;
    // Generator that produced f; used (via get_target()) to pick the Target for realize().
    std::shared_ptr<GeneratorBase> generator;

    // Verifies the output is ready to be realized; `m` names the calling
    // method for diagnostics. (Defined out-of-line.)
    void check_scheduled(const char *m) const;
    Target get_target() const;

    StubOutputBufferBase();
    explicit StubOutputBufferBase(const Func &f, const std::shared_ptr<GeneratorBase> &generator);

public:
    // Realize into a freshly-allocated Realization of the given sizes.
    Realization realize(std::vector<int32_t> sizes);

    // Forward arbitrary realize() arguments to the Func, appending the
    // Generator's Target as the final argument.
    template<typename... Args>
    Realization realize(Args &&...args) {
        check_scheduled("realize");
        return f.realize(std::forward<Args>(args)..., get_target());
    }

    // Realize into a caller-provided destination (e.g. an existing Buffer).
    template<typename Dst>
    void realize(Dst dst) {
        check_scheduled("realize");
        f.realize(dst, get_target());
    }
};
24274
24275/**
24276 * StubOutputBuffer is the placeholder that a Stub uses when it requires
24277 * a Buffer for an output (rather than merely a Func). It is constructed
24278 * to allow only two possible sorts of things:
24279 * -- Assignment to an Output<Buffer<>>, with compatible type and dimensions,
24280 * essentially allowing us to pipe a parameter from the result of a Stub to an
24281 * enclosing Generator
24282 * -- Realization into a Buffer<>; this is useful only in JIT compilation modes
24283 * (and shouldn't be usable otherwise)
24284 *
24285 * It is deliberate that StubOutputBuffer is not (easily) convertible to Func.
24286 */
template<typename T = void>
class StubOutputBuffer : public StubOutputBufferBase {
    template<typename T2>
    friend class GeneratorOutput_Buffer;
    friend class GeneratorStub;
    // Private: only the friends above may create a bound StubOutputBuffer.
    explicit StubOutputBuffer(const Func &f, const std::shared_ptr<GeneratorBase> &generator)
        : StubOutputBufferBase(f, generator) {
    }

public:
    StubOutputBuffer() = default;
};
24299
24300// This is a union-like class that allows for convenient initialization of Stub Inputs
24301// via C++11 initializer-list syntax; it is only used in situations where the
24302// downstream consumer will be able to explicitly check that each value is
24303// of the expected/required kind.
class StubInput {
    // Discriminant selecting which of the three fields below is active.
    const IOKind kind_;
    // Exactly one of the following fields should be defined:
    const Parameter parameter_;
    const Func func_;
    const Expr expr_;

public:
    // *not* explicit. Each ctor tags the input with the matching IOKind and
    // leaves the other two alternatives default-constructed.
    template<typename T2>
    StubInput(const StubInputBuffer<T2> &b)
        : kind_(IOKind::Buffer), parameter_(b.parameter_), func_(), expr_() {
    }
    StubInput(const Func &f)
        : kind_(IOKind::Function), parameter_(), func_(f), expr_() {
    }
    StubInput(const Expr &e)
        : kind_(IOKind::Scalar), parameter_(), func_(), expr_(e) {
    }

private:
    friend class GeneratorInputBase;

    IOKind kind() const {
        return kind_;
    }

    // The accessors below assert that the stored alternative matches the
    // kind being requested (callers are expected to check kind() first).
    Parameter parameter() const {
        internal_assert(kind_ == IOKind::Buffer);
        return parameter_;
    }

    Func func() const {
        internal_assert(kind_ == IOKind::Function);
        return func_;
    }

    Expr expr() const {
        internal_assert(kind_ == IOKind::Scalar);
        return expr_;
    }
};
24346
24347/** GIOBase is the base class for all GeneratorInput<> and GeneratorOutput<>
24348 * instantiations; it is not part of the public API and should never be
24349 * used directly by user code.
24350 *
24351 * Every GIOBase instance can be either a single value or an array-of-values;
24352 * each of these values can be an Expr or a Func. (Note that for an
24353 * array-of-values, the types/dimensions of all values in the array must match.)
24354 *
24355 * A GIOBase can have multiple Types, in which case it represents a Tuple.
24356 * (Note that Tuples are currently only supported for GeneratorOutput, but
24357 * it is likely that GeneratorInput will be extended to support Tuple as well.)
24358 *
24359 * The array-size, type(s), and dimensions can all be left "unspecified" at
24360 * creation time, in which case they may assume values provided by a Stub.
24361 * (It is important to note that attempting to use a GIOBase with unspecified
24362 * values will assert-fail; you must ensure that all unspecified values are
24363 * filled in prior to use.)
24364 */
class GIOBase {
public:
    // Queries for array-size/type/dims; the *_defined() forms report whether
    // the corresponding value has been specified yet (see class comment above).
    bool array_size_defined() const;
    size_t array_size() const;
    virtual bool is_array() const;

    const std::string &name() const;
    IOKind kind() const;

    bool types_defined() const;
    const std::vector<Type> &types() const;
    Type type() const;

    bool dims_defined() const;
    int dims() const;

    const std::vector<Func> &funcs() const;
    const std::vector<Expr> &exprs() const;

    virtual ~GIOBase() = default;

    // Fill in values that were left unspecified at construction time
    // (e.g. supplied later by a Stub).
    void set_type(const Type &type);
    void set_dimensions(int dims);
    void set_array_size(int size);

protected:
    GIOBase(size_t array_size,
            const std::string &name,
            IOKind kind,
            const std::vector<Type> &types,
            int dims);

    friend class GeneratorBase;
    friend class GeneratorParamInfo;

    mutable int array_size_;  // always 1 if is_array() == false.
                              // -1 if is_array() == true but unspecified.

    const std::string name_;
    const IOKind kind_;
    mutable std::vector<Type> types_;  // empty if type is unspecified
    mutable int dims_;                 // -1 if dim is unspecified

    // Exactly one of these will have nonzero length
    std::vector<Func> funcs_;
    std::vector<Expr> exprs_;

    // Generator which owns this Input or Output. Note that this will be null
    // initially; the GeneratorBase itself will set this field when it initially
    // builds its info about params. However, since it isn't
    // appropriate for Input<> or Output<> to be declared outside of a Generator,
    // all reasonable non-testing code should expect this to be non-null.
    GeneratorBase *generator{nullptr};

    // Name used for element i of an array-valued Input/Output.
    std::string array_name(size_t i) const;

    virtual void verify_internals();

    // Verify that a newly-provided size/type/dims value agrees with any
    // value that has already been specified.
    void check_matching_array_size(size_t size) const;
    void check_matching_types(const std::vector<Type> &t) const;
    void check_matching_dims(int d) const;

    // Returns funcs_ or exprs_ depending on ElemType; see the two
    // specializations that follow this class.
    template<typename ElemType>
    const std::vector<ElemType> &get_values() const;

    void check_gio_access() const;

    virtual void check_value_writable() const = 0;

    // Returns "Input" or "Output", for use in diagnostics.
    virtual const char *input_or_output() const = 0;

private:
    template<typename T>
    friend class GeneratorParam_Synthetic;

public:
    // Neither copyable nor movable: a GIOBase is expected to live as a member
    // of its Generator for the Generator's lifetime.
    GIOBase(const GIOBase &) = delete;
    GIOBase &operator=(const GIOBase &) = delete;
    GIOBase(GIOBase &&) = delete;
    GIOBase &operator=(GIOBase &&) = delete;
};
24446
// get_values<Expr>() resolves to the exprs() vector.
template<>
inline const std::vector<Expr> &GIOBase::get_values<Expr>() const {
    return exprs();
}
24451
// get_values<Func>() resolves to the funcs() vector.
template<>
inline const std::vector<Func> &GIOBase::get_values<Func>() const {
    return funcs();
}
24456
// Non-template base for all GeneratorInput<> instantiations; adds the
// runtime Parameters backing the input, plus estimate-setting plumbing.
class GeneratorInputBase : public GIOBase {
protected:
    GeneratorInputBase(size_t array_size,
                       const std::string &name,
                       IOKind kind,
                       const std::vector<Type> &t,
                       int d);

    // Convenience overload for the non-array (array_size == 1) case.
    GeneratorInputBase(const std::string &name, IOKind kind, const std::vector<Type> &t, int d);

    friend class GeneratorBase;
    friend class GeneratorParamInfo;

    // One Parameter per element (one entry for non-array inputs).
    std::vector<Parameter> parameters_;

    Parameter parameter() const;

    void init_internals();
    // Bind this input to the values a Stub supplied.
    void set_inputs(const std::vector<StubInput> &inputs);

    // Hook for subclasses to apply default/min/max values to parameters_.
    virtual void set_def_min_max();

    void verify_internals() override;

    friend class StubEmitter;

    // C++ type name to use for this input in generated stub code.
    virtual std::string get_c_type() const = 0;

    void check_value_writable() const override;

    const char *input_or_output() const override {
        return "Input";
    }

    // Shared implementations behind the public set_estimate()/set_estimates()
    // methods of the concrete subclasses.
    void set_estimate_impl(const Var &var, const Expr &min, const Expr &extent);
    void set_estimates_impl(const Region &estimates);

public:
    ~GeneratorInputBase() override;
};
24497
// Common implementation layer for GeneratorInput<T> of every kind.
// ValueType is Func or Expr, selecting which GIOBase vector backs the input.
// The three SFINAE'd constructors map the declared C++ type T onto an
// array-size: scalar T -> 1, T[N] -> N, T[] -> unspecified (-1).
template<typename T, typename ValueType>
class GeneratorInputImpl : public GeneratorInputBase {
protected:
    using TBase = typename std::remove_all_extents<T>::type;

    bool is_array() const override {
        return std::is_array<T>::value;
    }

    template<typename T2 = T, typename std::enable_if<
                                  // Only allow T2 not-an-array
                                  !std::is_array<T2>::value>::type * = nullptr>
    GeneratorInputImpl(const std::string &name, IOKind kind, const std::vector<Type> &t, int d)
        : GeneratorInputBase(name, kind, t, d) {
    }

    template<typename T2 = T, typename std::enable_if<
                                  // Only allow T2[kSomeConst]
                                  std::is_array<T2>::value && std::rank<T2>::value == 1 && (std::extent<T2, 0>::value > 0)>::type * = nullptr>
    GeneratorInputImpl(const std::string &name, IOKind kind, const std::vector<Type> &t, int d)
        : GeneratorInputBase(std::extent<T2, 0>::value, name, kind, t, d) {
    }

    template<typename T2 = T, typename std::enable_if<
                                  // Only allow T2[]
                                  std::is_array<T2>::value && std::rank<T2>::value == 1 && std::extent<T2, 0>::value == 0>::type * = nullptr>
    GeneratorInputImpl(const std::string &name, IOKind kind, const std::vector<Type> &t, int d)
        : GeneratorInputBase(-1, name, kind, t, d) {
    }

public:
    // Container-style accessors; available only when T is an array type.
    template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
    size_t size() const {
        this->check_gio_access();
        return get_values<ValueType>().size();
    }

    template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
    const ValueType &operator[](size_t i) const {
        this->check_gio_access();
        return get_values<ValueType>()[i];
    }

    template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
    const ValueType &at(size_t i) const {
        this->check_gio_access();
        return get_values<ValueType>().at(i);
    }

    template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
    typename std::vector<ValueType>::const_iterator begin() const {
        this->check_gio_access();
        return get_values<ValueType>().begin();
    }

    template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
    typename std::vector<ValueType>::const_iterator end() const {
        this->check_gio_access();
        return get_values<ValueType>().end();
    }
};
24559
24560// When forwarding methods to ImageParam, Func, etc., we must take
24561// care with the return types: many of the methods return a reference-to-self
24562// (e.g., ImageParam&); since we create temporaries for most of these forwards,
24563// returning a ref will crater because it refers to a now-defunct section of the
// stack. Happily, simply removing the reference solves this, since all of the
24565// types in question satisfy the property of copies referring to the same underlying
24566// structure (returning references is just an optimization). Since this is verbose
24567// and used in several places, we'll use a helper macro:
// Forwards Method(args...) to a temporary of type Class obtained via
// this->as<Class>(). The return type deliberately has any reference stripped
// (see comment above): the temporary dies when the call returns, so a
// returned reference into it would dangle.
#define HALIDE_FORWARD_METHOD(Class, Method) \
    template<typename... Args> \
    inline auto Method(Args &&...args)->typename std::remove_reference<decltype(std::declval<Class>().Method(std::forward<Args>(args)...))>::type { \
        return this->template as<Class>().Method(std::forward<Args>(args)...); \
    }

// Const variant; additionally calls check_gio_access() before forwarding.
#define HALIDE_FORWARD_METHOD_CONST(Class, Method) \
    template<typename... Args> \
    inline auto Method(Args &&...args) const-> \
        typename std::remove_reference<decltype(std::declval<Class>().Method(std::forward<Args>(args)...))>::type { \
        this->check_gio_access(); \
        return this->template as<Class>().Method(std::forward<Args>(args)...); \
    }
24581
// Implementation of GeneratorInput<Buffer<...>>. Backed by Funcs (one per
// array element) plus the Parameters in GeneratorInputBase; convertible to
// Func, ImageParam, ExternFuncArgument, and StubInputBuffer.
template<typename T>
class GeneratorInput_Buffer : public GeneratorInputImpl<T, Func> {
private:
    using Super = GeneratorInputImpl<T, Func>;

protected:
    using TBase = typename Super::TBase;

    friend class ::Halide::Func;
    friend class ::Halide::Stage;

    // Stub code type: StubInputBuffer<ElemType>, or StubInputBuffer<> when
    // the buffer's element type is not statically known.
    std::string get_c_type() const override {
        if (TBase::has_static_halide_type) {
            return "Halide::Internal::StubInputBuffer<" +
                   halide_type_to_c_type(TBase::static_halide_type()) +
                   ">";
        } else {
            return "Halide::Internal::StubInputBuffer<>";
        }
    }

    // Used by the HALIDE_FORWARD_METHOD* macros: convert to the target class
    // via the conversion operators below.
    template<typename T2>
    inline T2 as() const {
        return (T2) * this;
    }

public:
    // Type is taken from TBase when statically known, else left unspecified;
    // dimensionality left unspecified (-1).
    GeneratorInput_Buffer(const std::string &name)
        : Super(name, IOKind::Buffer,
                TBase::has_static_halide_type ? std::vector<Type>{TBase::static_halide_type()} : std::vector<Type>{},
                -1) {
    }

    // Explicit type (only legal when TBase has no static type), optional dims.
    GeneratorInput_Buffer(const std::string &name, const Type &t, int d = -1)
        : Super(name, IOKind::Buffer, {t}, d) {
        static_assert(!TBase::has_static_halide_type, "You can only specify a Type argument for Input<Buffer<T>> if T is void or omitted.");
    }

    // Explicit dims; type taken from TBase when statically known.
    GeneratorInput_Buffer(const std::string &name, int d)
        : Super(name, IOKind::Buffer, TBase::has_static_halide_type ? std::vector<Type>{TBase::static_halide_type()} : std::vector<Type>{}, d) {
    }

    /** Use this Input as if it were a Func: index it to get an Expr. */
    template<typename... Args>
    Expr operator()(Args &&...args) const {
        this->check_gio_access();
        return Func(*this)(std::forward<Args>(args)...);
    }

    Expr operator()(std::vector<Expr> args) const {
        this->check_gio_access();
        return Func(*this)(std::move(args));
    }

    // Conversion used when piping this Input into a Stub's input.
    // NOTE(review): unlike the other accessors here, this one does not call
    // check_gio_access() -- confirm whether that is intentional.
    template<typename T2>
    operator StubInputBuffer<T2>() const {
        user_assert(!this->is_array()) << "Cannot assign an array type to a non-array type for Input " << this->name();
        return StubInputBuffer<T2>(this->parameters_.at(0));
    }

    operator Func() const {
        this->check_gio_access();
        return this->funcs().at(0);
    }

    operator ExternFuncArgument() const {
        this->check_gio_access();
        return ExternFuncArgument(this->parameters_.at(0));
    }

    /** Set bounds estimates for autoscheduling. Returns *this for chaining. */
    GeneratorInput_Buffer<T> &set_estimate(Var var, Expr min, Expr extent) {
        this->check_gio_access();
        this->set_estimate_impl(var, min, extent);
        return *this;
    }

    GeneratorInput_Buffer<T> &set_estimates(const Region &estimates) {
        this->check_gio_access();
        this->set_estimates_impl(estimates);
        return *this;
    }

    /** Forward Func::in() scheduling helpers. */
    Func in() {
        this->check_gio_access();
        return Func(*this).in();
    }

    Func in(const Func &other) {
        this->check_gio_access();
        return Func(*this).in(other);
    }

    Func in(const std::vector<Func> &others) {
        this->check_gio_access();
        return Func(*this).in(others);
    }

    operator ImageParam() const {
        this->check_gio_access();
        user_assert(!this->is_array()) << "Cannot convert an Input<Buffer<>[]> to an ImageParam; use an explicit subscript operator: " << this->name();
        return ImageParam(this->parameters_.at(0), Func(*this));
    }

    // Array-only accessors; elements are exposed as ImageParams.
    template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
    size_t size() const {
        this->check_gio_access();
        return this->parameters_.size();
    }

    template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
    ImageParam operator[](size_t i) const {
        this->check_gio_access();
        return ImageParam(this->parameters_.at(i), this->funcs().at(i));
    }

    template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
    ImageParam at(size_t i) const {
        this->check_gio_access();
        return ImageParam(this->parameters_.at(i), this->funcs().at(i));
    }

    // Iteration would require materializing a vector<ImageParam>, so it is
    // deliberately unsupported (user_error at runtime).
    template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
    typename std::vector<ImageParam>::const_iterator begin() const {
        user_error << "Input<Buffer<>>::begin() is not supported.";
        return {};
    }

    template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
    typename std::vector<ImageParam>::const_iterator end() const {
        user_error << "Input<Buffer<>>::end() is not supported.";
        return {};
    }

    /** Forward methods to the ImageParam. */
    // @{
    HALIDE_FORWARD_METHOD(ImageParam, dim)
    HALIDE_FORWARD_METHOD_CONST(ImageParam, dim)
    HALIDE_FORWARD_METHOD_CONST(ImageParam, host_alignment)
    HALIDE_FORWARD_METHOD(ImageParam, set_host_alignment)
    HALIDE_FORWARD_METHOD(ImageParam, store_in)
    HALIDE_FORWARD_METHOD_CONST(ImageParam, dimensions)
    HALIDE_FORWARD_METHOD_CONST(ImageParam, left)
    HALIDE_FORWARD_METHOD_CONST(ImageParam, right)
    HALIDE_FORWARD_METHOD_CONST(ImageParam, top)
    HALIDE_FORWARD_METHOD_CONST(ImageParam, bottom)
    HALIDE_FORWARD_METHOD_CONST(ImageParam, width)
    HALIDE_FORWARD_METHOD_CONST(ImageParam, height)
    HALIDE_FORWARD_METHOD_CONST(ImageParam, channels)
    HALIDE_FORWARD_METHOD_CONST(ImageParam, trace_loads)
    HALIDE_FORWARD_METHOD_CONST(ImageParam, add_trace_tag)
    // @}
};
24733
// Implementation of GeneratorInput<Func>. The eight constructors cover every
// combination of {array, non-array} x {type specified, unspecified} x
// {dims specified, unspecified}.
template<typename T>
class GeneratorInput_Func : public GeneratorInputImpl<T, Func> {
private:
    using Super = GeneratorInputImpl<T, Func>;

protected:
    using TBase = typename Super::TBase;

    std::string get_c_type() const override {
        return "Func";
    }

    // Used by the HALIDE_FORWARD_METHOD* macros: convert to the target class
    // via the conversion operators below.
    template<typename T2>
    inline T2 as() const {
        return (T2) * this;
    }

public:
    GeneratorInput_Func(const std::string &name, const Type &t, int d)
        : Super(name, IOKind::Function, {t}, d) {
    }

    // unspecified type
    GeneratorInput_Func(const std::string &name, int d)
        : Super(name, IOKind::Function, {}, d) {
    }

    // unspecified dimension
    GeneratorInput_Func(const std::string &name, const Type &t)
        : Super(name, IOKind::Function, {t}, -1) {
    }

    // unspecified type & dimension
    GeneratorInput_Func(const std::string &name)
        : Super(name, IOKind::Function, {}, -1) {
    }

    GeneratorInput_Func(size_t array_size, const std::string &name, const Type &t, int d)
        : Super(array_size, name, IOKind::Function, {t}, d) {
    }

    // unspecified type
    GeneratorInput_Func(size_t array_size, const std::string &name, int d)
        : Super(array_size, name, IOKind::Function, {}, d) {
    }

    // unspecified dimension
    GeneratorInput_Func(size_t array_size, const std::string &name, const Type &t)
        : Super(array_size, name, IOKind::Function, {t}, -1) {
    }

    // unspecified type & dimension
    GeneratorInput_Func(size_t array_size, const std::string &name)
        : Super(array_size, name, IOKind::Function, {}, -1) {
    }

    /** Use this Input as if it were a Func: index it to get an Expr. */
    template<typename... Args>
    Expr operator()(Args &&...args) const {
        this->check_gio_access();
        return this->funcs().at(0)(std::forward<Args>(args)...);
    }

    Expr operator()(const std::vector<Expr> &args) const {
        this->check_gio_access();
        return this->funcs().at(0)(args);
    }

    operator Func() const {
        this->check_gio_access();
        return this->funcs().at(0);
    }

    operator ExternFuncArgument() const {
        this->check_gio_access();
        return ExternFuncArgument(this->parameters_.at(0));
    }

    /** Set bounds estimates for autoscheduling. Returns *this for chaining. */
    GeneratorInput_Func<T> &set_estimate(Var var, Expr min, Expr extent) {
        this->check_gio_access();
        this->set_estimate_impl(var, min, extent);
        return *this;
    }

    GeneratorInput_Func<T> &set_estimates(const Region &estimates) {
        this->check_gio_access();
        this->set_estimates_impl(estimates);
        return *this;
    }

    /** Forward Func::in() scheduling helpers. */
    Func in() {
        this->check_gio_access();
        return Func(*this).in();
    }

    Func in(const Func &other) {
        this->check_gio_access();
        return Func(*this).in(other);
    }

    Func in(const std::vector<Func> &others) {
        this->check_gio_access();
        return Func(*this).in(others);
    }

    /** Forward const methods to the underlying Func. (Non-const methods
     * aren't available for Input<Func>.) */
    // @{
    HALIDE_FORWARD_METHOD_CONST(Func, args)
    HALIDE_FORWARD_METHOD_CONST(Func, defined)
    HALIDE_FORWARD_METHOD_CONST(Func, has_update_definition)
    HALIDE_FORWARD_METHOD_CONST(Func, num_update_definitions)
    HALIDE_FORWARD_METHOD_CONST(Func, output_types)
    HALIDE_FORWARD_METHOD_CONST(Func, outputs)
    HALIDE_FORWARD_METHOD_CONST(Func, rvars)
    HALIDE_FORWARD_METHOD_CONST(Func, update_args)
    HALIDE_FORWARD_METHOD_CONST(Func, update_value)
    HALIDE_FORWARD_METHOD_CONST(Func, update_values)
    HALIDE_FORWARD_METHOD_CONST(Func, value)
    HALIDE_FORWARD_METHOD_CONST(Func, values)
    // @}
};
24855
// Implementation of GeneratorInput<Expr>: a scalar input whose concrete
// numeric type is not fixed at C++ compile time.
template<typename T>
class GeneratorInput_DynamicScalar : public GeneratorInputImpl<T, Expr> {
private:
    using Super = GeneratorInputImpl<T, Expr>;

    static_assert(std::is_same<typename std::remove_all_extents<T>::type, Expr>::value, "GeneratorInput_DynamicScalar is only legal to use with T=Expr for now");

protected:
    std::string get_c_type() const override {
        return "Expr";
    }

public:
    explicit GeneratorInput_DynamicScalar(const std::string &name)
        : Super(name, IOKind::Scalar, {}, 0) {
        user_assert(!std::is_array<T>::value) << "Input<Expr[]> is not allowed";
    }

    /** You can use this Input as an expression in a halide
     * function definition */
    operator Expr() const {
        this->check_gio_access();
        return this->exprs().at(0);
    }

    /** Using an Input as the argument to an external stage treats it
     * as an Expr */
    operator ExternFuncArgument() const {
        this->check_gio_access();
        return ExternFuncArgument(this->exprs().at(0));
    }

    // Set the same estimate on every backing Parameter (one per array element).
    void set_estimate(const Expr &value) {
        this->check_gio_access();
        for (Parameter &p : this->parameters_) {
            p.set_estimate(value);
        }
    }
};
24895
// Implementation of GeneratorInput<T> for scalar T (bool, pointer/handle,
// and -- via the GeneratorInput_Arithmetic subclass -- arithmetic types).
// Stores a default value both as a raw TBase and as an Expr.
template<typename T>
class GeneratorInput_Scalar : public GeneratorInputImpl<T, Expr> {
private:
    using Super = GeneratorInputImpl<T, Expr>;

protected:
    using TBase = typename Super::TBase;

    const TBase def_{TBase()};
    const Expr def_expr_;

    // Push the default value into every backing Parameter.
    void set_def_min_max() override {
        for (Parameter &p : this->parameters_) {
            p.set_scalar<TBase>(def_);
            p.set_default_value(def_expr_);
        }
    }

    std::string get_c_type() const override {
        return "Expr";
    }

    // Expr() doesn't accept a pointer type in its ctor; add a SFINAE adapter
    // so that pointer (aka handle) Inputs will get cast to uint64.
    template<typename TBase2 = TBase, typename std::enable_if<!std::is_pointer<TBase2>::value>::type * = nullptr>
    static Expr TBaseToExpr(const TBase2 &value) {
        return cast<TBase>(Expr(value));
    }

    template<typename TBase2 = TBase, typename std::enable_if<std::is_pointer<TBase2>::value>::type * = nullptr>
    static Expr TBaseToExpr(const TBase2 &value) {
        user_assert(value == 0) << "Zero is the only legal default value for Inputs which are pointer types.\n";
        return Expr();
    }

public:
    explicit GeneratorInput_Scalar(const std::string &name)
        : Super(name, IOKind::Scalar, {type_of<TBase>()}, 0), def_(static_cast<TBase>(0)), def_expr_(Expr()) {
    }

    GeneratorInput_Scalar(const std::string &name, const TBase &def)
        : Super(name, IOKind::Scalar, {type_of<TBase>()}, 0), def_(def), def_expr_(TBaseToExpr(def)) {
    }

    GeneratorInput_Scalar(size_t array_size,
                          const std::string &name)
        : Super(array_size, name, IOKind::Scalar, {type_of<TBase>()}, 0), def_(static_cast<TBase>(0)), def_expr_(Expr()) {
    }

    GeneratorInput_Scalar(size_t array_size,
                          const std::string &name,
                          const TBase &def)
        : Super(array_size, name, IOKind::Scalar, {type_of<TBase>()}, 0), def_(def), def_expr_(TBaseToExpr(def)) {
    }

    /** You can use this Input as an expression in a halide
     * function definition */
    operator Expr() const {
        this->check_gio_access();
        return this->exprs().at(0);
    }

    /** Using an Input as the argument to an external stage treats it
     * as an Expr */
    operator ExternFuncArgument() const {
        this->check_gio_access();
        return ExternFuncArgument(this->exprs().at(0));
    }

    // Pointer-typed inputs: nullptr is the only representable estimate,
    // expressed as a reinterpret of a zero uint64.
    template<typename T2 = T, typename std::enable_if<std::is_pointer<T2>::value>::type * = nullptr>
    void set_estimate(const TBase &value) {
        this->check_gio_access();
        user_assert(value == nullptr) << "nullptr is the only valid estimate for Input<PointerType>";
        Expr e = reinterpret(type_of<T2>(), cast<uint64_t>(0));
        for (Parameter &p : this->parameters_) {
            p.set_estimate(e);
        }
    }

    // Non-array, non-pointer scalar inputs: set the same estimate on every
    // Parameter (bool values are explicitly cast to match the param type).
    template<typename T2 = T, typename std::enable_if<!std::is_array<T2>::value && !std::is_pointer<T2>::value>::type * = nullptr>
    void set_estimate(const TBase &value) {
        this->check_gio_access();
        Expr e = Expr(value);
        if (std::is_same<T2, bool>::value) {
            e = cast<bool>(e);
        }
        for (Parameter &p : this->parameters_) {
            p.set_estimate(e);
        }
    }

    // Array inputs: set an estimate on a single element.
    template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
    void set_estimate(size_t index, const TBase &value) {
        this->check_gio_access();
        Expr e = Expr(value);
        if (std::is_same<T2, bool>::value) {
            e = cast<bool>(e);
        }
        this->parameters_.at(index).set_estimate(e);
    }
};
24997
// Implementation of GeneratorInput<T> for arithmetic T: adds optional
// min/max constraints on top of GeneratorInput_Scalar's default value.
template<typename T>
class GeneratorInput_Arithmetic : public GeneratorInput_Scalar<T> {
private:
    using Super = GeneratorInput_Scalar<T>;

protected:
    using TBase = typename Super::TBase;

    // Undefined Exprs mean "no constraint".
    const Expr min_, max_;

    void set_def_min_max() override {
        Super::set_def_min_max();
        // Don't set min/max for bool
        if (!std::is_same<TBase, bool>::value) {
            for (Parameter &p : this->parameters_) {
                if (min_.defined()) {
                    p.set_min_value(min_);
                }
                if (max_.defined()) {
                    p.set_max_value(max_);
                }
            }
        }
    }

public:
    explicit GeneratorInput_Arithmetic(const std::string &name)
        : Super(name), min_(Expr()), max_(Expr()) {
    }

    GeneratorInput_Arithmetic(const std::string &name,
                              const TBase &def)
        : Super(name, def), min_(Expr()), max_(Expr()) {
    }

    GeneratorInput_Arithmetic(size_t array_size,
                              const std::string &name)
        : Super(array_size, name), min_(Expr()), max_(Expr()) {
    }

    GeneratorInput_Arithmetic(size_t array_size,
                              const std::string &name,
                              const TBase &def)
        : Super(array_size, name, def), min_(Expr()), max_(Expr()) {
    }

    GeneratorInput_Arithmetic(const std::string &name,
                              const TBase &def,
                              const TBase &min,
                              const TBase &max)
        : Super(name, def), min_(min), max_(max) {
    }

    GeneratorInput_Arithmetic(size_t array_size,
                              const std::string &name,
                              const TBase &def,
                              const TBase &min,
                              const TBase &max)
        : Super(array_size, name, def), min_(min), max_(max) {
    }
};
25059
// Maps any well-formed type to void; used below for expression-SFINAE
// detection. (Modernized: `using` alias instead of `typedef`, matching the
// C++11 style used elsewhere in this header.)
template<typename>
struct type_sink { using type = void; };

// Compile-time detection of whether T2 declares a static_halide_type()
// member: true_type for Buffer-like types with static type information,
// false_type otherwise. Used by GeneratorInputImplBase to pick the
// Buffer implementation class.
template<typename T2, typename = void>
struct has_static_halide_type_method : std::false_type {};

template<typename T2>
struct has_static_halide_type_method<T2, typename type_sink<decltype(T2::static_halide_type())>::type> : std::true_type {};
25068
// Selects the implementation class for GeneratorInput<T>, tested in order:
//   Buffer-like (has static_halide_type()) -> GeneratorInput_Buffer
//   Func                                   -> GeneratorInput_Func
//   arithmetic                             -> GeneratorInput_Arithmetic
//   other scalar (e.g. pointer)            -> GeneratorInput_Scalar
//   Expr                                   -> GeneratorInput_DynamicScalar
template<typename T, typename TBase = typename std::remove_all_extents<T>::type>
using GeneratorInputImplBase =
    typename select_type<
        cond<has_static_halide_type_method<TBase>::value, GeneratorInput_Buffer<T>>,
        cond<std::is_same<TBase, Func>::value, GeneratorInput_Func<T>>,
        cond<std::is_arithmetic<TBase>::value, GeneratorInput_Arithmetic<T>>,
        cond<std::is_scalar<TBase>::value, GeneratorInput_Scalar<T>>,
        cond<std::is_same<TBase, Expr>::value, GeneratorInput_DynamicScalar<T>>>::type;
25077
25078} // namespace Internal
25079
// Public GeneratorInput<T>: a thin constructor-forwarding wrapper over the
// implementation class chosen by GeneratorInputImplBase<T>.
template<typename T>
class GeneratorInput : public Internal::GeneratorInputImplBase<T> {
private:
    using Super = Internal::GeneratorInputImplBase<T>;

protected:
    using TBase = typename Super::TBase;

    // Trick to avoid ambiguous ctor between Func-with-dim and int-with-default-value;
    // since we can't use std::enable_if on ctors, define the argument to be one that
    // can only be properly resolved for TBase=Func.
    struct Unused;
    using IntIfNonScalar =
        typename Internal::select_type<
            Internal::cond<Internal::has_static_halide_type_method<TBase>::value, int>,
            Internal::cond<std::is_same<TBase, Func>::value, int>,
            Internal::cond<true, Unused>>::type;

public:
    explicit GeneratorInput(const std::string &name)
        : Super(name) {
    }

    // Scalar with default value.
    GeneratorInput(const std::string &name, const TBase &def)
        : Super(name, def) {
    }

    GeneratorInput(size_t array_size, const std::string &name, const TBase &def)
        : Super(array_size, name, def) {
    }

    // Scalar with default value plus min/max constraints.
    GeneratorInput(const std::string &name,
                   const TBase &def, const TBase &min, const TBase &max)
        : Super(name, def, min, max) {
    }

    GeneratorInput(size_t array_size, const std::string &name,
                   const TBase &def, const TBase &min, const TBase &max)
        : Super(array_size, name, def, min, max) {
    }

    // Func/Buffer with explicit type (and optionally dimensionality).
    GeneratorInput(const std::string &name, const Type &t, int d)
        : Super(name, t, d) {
    }

    GeneratorInput(const std::string &name, const Type &t)
        : Super(name, t) {
    }

    // Avoid ambiguity between Func-with-dim and int-with-default
    GeneratorInput(const std::string &name, IntIfNonScalar d)
        : Super(name, d) {
    }

    GeneratorInput(size_t array_size, const std::string &name, const Type &t, int d)
        : Super(array_size, name, t, d) {
    }

    GeneratorInput(size_t array_size, const std::string &name, const Type &t)
        : Super(array_size, name, t) {
    }

    // Avoid ambiguity between Func-with-dim and int-with-default
    //template <typename T2 = T, typename std::enable_if<std::is_same<TBase, Func>::value>::type * = nullptr>
    GeneratorInput(size_t array_size, const std::string &name, IntIfNonScalar d)
        : Super(array_size, name, d) {
    }

    GeneratorInput(size_t array_size, const std::string &name)
        : Super(array_size, name) {
    }
};
25152
25153namespace Internal {
25154
class GeneratorOutputBase : public GIOBase {
protected:
    // Return the single underlying Func of this Output; only valid for
    // non-array, non-scalar Outputs (asserts otherwise). T2 must be Func.
    template<typename T2, typename std::enable_if<std::is_same<T2, Func>::value>::type * = nullptr>
    HALIDE_NO_USER_CODE_INLINE T2 as() const {
        static_assert(std::is_same<T2, Func>::value, "Only Func allowed here");
        internal_assert(kind() != IOKind::Scalar);
        internal_assert(exprs_.empty());
        user_assert(funcs_.size() == 1) << "Use [] to access individual Funcs in Output<Func[]>";
        return funcs_[0];
    }

public:
    /** Forward schedule-related methods to the underlying Func. */
    // @{
    HALIDE_FORWARD_METHOD(Func, add_trace_tag)
    HALIDE_FORWARD_METHOD(Func, align_bounds)
    HALIDE_FORWARD_METHOD(Func, align_extent)
    HALIDE_FORWARD_METHOD(Func, align_storage)
    HALIDE_FORWARD_METHOD_CONST(Func, args)
    HALIDE_FORWARD_METHOD(Func, bound)
    HALIDE_FORWARD_METHOD(Func, bound_extent)
    HALIDE_FORWARD_METHOD(Func, compute_at)
    HALIDE_FORWARD_METHOD(Func, compute_inline)
    HALIDE_FORWARD_METHOD(Func, compute_root)
    HALIDE_FORWARD_METHOD(Func, compute_with)
    HALIDE_FORWARD_METHOD(Func, copy_to_device)
    HALIDE_FORWARD_METHOD(Func, copy_to_host)
    HALIDE_FORWARD_METHOD(Func, define_extern)
    HALIDE_FORWARD_METHOD_CONST(Func, defined)
    HALIDE_FORWARD_METHOD(Func, fold_storage)
    HALIDE_FORWARD_METHOD(Func, fuse)
    HALIDE_FORWARD_METHOD(Func, gpu)
    HALIDE_FORWARD_METHOD(Func, gpu_blocks)
    HALIDE_FORWARD_METHOD(Func, gpu_single_thread)
    HALIDE_FORWARD_METHOD(Func, gpu_threads)
    HALIDE_FORWARD_METHOD(Func, gpu_tile)
    HALIDE_FORWARD_METHOD_CONST(Func, has_update_definition)
    HALIDE_FORWARD_METHOD(Func, hexagon)
    HALIDE_FORWARD_METHOD(Func, in)
    HALIDE_FORWARD_METHOD(Func, memoize)
    HALIDE_FORWARD_METHOD_CONST(Func, num_update_definitions)
    HALIDE_FORWARD_METHOD_CONST(Func, output_types)
    HALIDE_FORWARD_METHOD_CONST(Func, outputs)
    HALIDE_FORWARD_METHOD(Func, parallel)
    HALIDE_FORWARD_METHOD(Func, prefetch)
    HALIDE_FORWARD_METHOD(Func, print_loop_nest)
    HALIDE_FORWARD_METHOD(Func, rename)
    HALIDE_FORWARD_METHOD(Func, reorder)
    HALIDE_FORWARD_METHOD(Func, reorder_storage)
    HALIDE_FORWARD_METHOD_CONST(Func, rvars)
    HALIDE_FORWARD_METHOD(Func, serial)
    HALIDE_FORWARD_METHOD(Func, set_estimate)
    HALIDE_FORWARD_METHOD(Func, specialize)
    HALIDE_FORWARD_METHOD(Func, specialize_fail)
    HALIDE_FORWARD_METHOD(Func, split)
    HALIDE_FORWARD_METHOD(Func, store_at)
    HALIDE_FORWARD_METHOD(Func, store_root)
    HALIDE_FORWARD_METHOD(Func, tile)
    HALIDE_FORWARD_METHOD(Func, trace_stores)
    HALIDE_FORWARD_METHOD(Func, unroll)
    HALIDE_FORWARD_METHOD(Func, update)
    HALIDE_FORWARD_METHOD_CONST(Func, update_args)
    HALIDE_FORWARD_METHOD_CONST(Func, update_value)
    HALIDE_FORWARD_METHOD_CONST(Func, update_values)
    HALIDE_FORWARD_METHOD_CONST(Func, value)
    HALIDE_FORWARD_METHOD_CONST(Func, values)
    HALIDE_FORWARD_METHOD(Func, vectorize)
    // @}

#undef HALIDE_OUTPUT_FORWARD
#undef HALIDE_OUTPUT_FORWARD_CONST

protected:
    // Construct an array Output ("Output<Func[]>") with the given array size.
    GeneratorOutputBase(size_t array_size,
                        const std::string &name,
                        IOKind kind,
                        const std::vector<Type> &t,
                        int d);

    // Construct a non-array Output.
    GeneratorOutputBase(const std::string &name,
                        IOKind kind,
                        const std::vector<Type> &t,
                        int d);

    friend class GeneratorBase;
    friend class StubEmitter;

    void init_internals();
    void resize(size_t size);

    // C++ type name used for this Output when emitting stub code.
    virtual std::string get_c_type() const {
        return "Func";
    }

    void check_value_writable() const override;

    // Used to produce "Output"-labeled diagnostics in shared GIOBase code.
    const char *input_or_output() const override {
        return "Output";
    }

public:
    ~GeneratorOutputBase() override;
};
25258
template<typename T>
class GeneratorOutputImpl : public GeneratorOutputBase {
protected:
    // TBase is T with any array extents stripped (e.g. Func[4] -> Func).
    using TBase = typename std::remove_all_extents<T>::type;
    using ValueType = Func;

    bool is_array() const override {
        return std::is_array<T>::value;
    }

    template<typename T2 = T, typename std::enable_if<
                                  // Only allow T2 not-an-array
                                  !std::is_array<T2>::value>::type * = nullptr>
    GeneratorOutputImpl(const std::string &name, IOKind kind, const std::vector<Type> &t, int d)
        : GeneratorOutputBase(name, kind, t, d) {
    }

    template<typename T2 = T, typename std::enable_if<
                                  // Only allow T2[kSomeConst]
                                  std::is_array<T2>::value && std::rank<T2>::value == 1 && (std::extent<T2, 0>::value > 0)>::type * = nullptr>
    GeneratorOutputImpl(const std::string &name, IOKind kind, const std::vector<Type> &t, int d)
        : GeneratorOutputBase(std::extent<T2, 0>::value, name, kind, t, d) {
    }

    template<typename T2 = T, typename std::enable_if<
                                  // Only allow T2[]
                                  std::is_array<T2>::value && std::rank<T2>::value == 1 && std::extent<T2, 0>::value == 0>::type * = nullptr>
    GeneratorOutputImpl(const std::string &name, IOKind kind, const std::vector<Type> &t, int d)
        : GeneratorOutputBase(-1, name, kind, t, d) {
    }

public:
    // Call the single underlying Func with the given args (non-array Outputs only).
    template<typename... Args, typename T2 = T, typename std::enable_if<!std::is_array<T2>::value>::type * = nullptr>
    FuncRef operator()(Args &&...args) const {
        this->check_gio_access();
        return get_values<ValueType>().at(0)(std::forward<Args>(args)...);
    }

    // Call the single underlying Func with a vector of Exprs or Vars (non-array Outputs only).
    template<typename ExprOrVar, typename T2 = T, typename std::enable_if<!std::is_array<T2>::value>::type * = nullptr>
    FuncRef operator()(std::vector<ExprOrVar> args) const {
        this->check_gio_access();
        return get_values<ValueType>().at(0)(args);
    }

    // Implicit conversion to the single underlying Func (non-array Outputs only).
    template<typename T2 = T, typename std::enable_if<!std::is_array<T2>::value>::type * = nullptr>
    operator Func() const {
        this->check_gio_access();
        return get_values<ValueType>().at(0);
    }

    // Implicit conversion to a Stage (non-array Outputs only).
    template<typename T2 = T, typename std::enable_if<!std::is_array<T2>::value>::type * = nullptr>
    operator Stage() const {
        this->check_gio_access();
        return get_values<ValueType>().at(0);
    }

    // Number of elements (array Outputs only).
    template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
    size_t size() const {
        this->check_gio_access();
        return get_values<ValueType>().size();
    }

    // Unchecked element access (array Outputs only).
    template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
    const ValueType &operator[](size_t i) const {
        this->check_gio_access();
        return get_values<ValueType>()[i];
    }

    // Bounds-checked element access (array Outputs only).
    template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
    const ValueType &at(size_t i) const {
        this->check_gio_access();
        return get_values<ValueType>().at(i);
    }

    // Iteration over elements (array Outputs only).
    template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
    typename std::vector<ValueType>::const_iterator begin() const {
        this->check_gio_access();
        return get_values<ValueType>().begin();
    }

    template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
    typename std::vector<ValueType>::const_iterator end() const {
        this->check_gio_access();
        return get_values<ValueType>().end();
    }

    // Set the number of elements at runtime; only for unsized-array
    // Outputs (T2[]), whose extent is not fixed at compile time.
    template<typename T2 = T, typename std::enable_if<
                                  // Only allow T2[]
                                  std::is_array<T2>::value && std::rank<T2>::value == 1 && std::extent<T2, 0>::value == 0>::type * = nullptr>
    void resize(size_t size) {
        this->check_gio_access();
        GeneratorOutputBase::resize(size);
    }
};
25353
template<typename T>
class GeneratorOutput_Buffer : public GeneratorOutputImpl<T> {
private:
    using Super = GeneratorOutputImpl<T>;

    // Bind the (as-yet-undefined) single underlying Func of this Output to f,
    // after verifying f's tuple types and dimensionality against any declared
    // for this Output. Asserts if the Output's Func is already defined.
    HALIDE_NO_USER_CODE_INLINE void assign_from_func(const Func &f) {
        this->check_value_writable();

        internal_assert(f.defined());

        if (this->types_defined()) {
            const auto &my_types = this->types();
            user_assert(my_types.size() == f.output_types().size())
                << "Cannot assign Func \"" << f.name()
                << "\" to Output \"" << this->name() << "\"\n"
                << "Output " << this->name()
                << " is declared to have " << my_types.size() << " tuple elements"
                << " but Func " << f.name()
                << " has " << f.output_types().size() << " tuple elements.\n";
            for (size_t i = 0; i < my_types.size(); i++) {
                user_assert(my_types[i] == f.output_types().at(i))
                    << "Cannot assign Func \"" << f.name()
                    << "\" to Output \"" << this->name() << "\"\n"
                    << (my_types.size() > 1 ? "In tuple element " + std::to_string(i) + ", " : "")
                    << "Output " << this->name()
                    << " has declared type " << my_types[i]
                    << " but Func " << f.name()
                    << " has type " << f.output_types().at(i) << "\n";
            }
        }
        if (this->dims_defined()) {
            user_assert(f.dimensions() == this->dims())
                << "Cannot assign Func \"" << f.name()
                << "\" to Output \"" << this->name() << "\"\n"
                << "Output " << this->name()
                << " has declared dimensionality " << this->dims()
                << " but Func " << f.name()
                << " has dimensionality " << f.dimensions() << "\n";
        }

        internal_assert(this->exprs_.empty() && this->funcs_.size() == 1);
        user_assert(!this->funcs_.at(0).defined());
        this->funcs_[0] = f;
    }

protected:
    using TBase = typename Super::TBase;

    // If TBase is a Buffer with a static (non-void) type, that type is the
    // one and only allowed type (and no Type argument may be passed);
    // otherwise the caller-supplied list t is used as-is.
    static std::vector<Type> my_types(const std::vector<Type> &t) {
        if (TBase::has_static_halide_type) {
            user_assert(t.empty()) << "Cannot pass a Type argument for an Output<Buffer> with a non-void static type\n";
            return std::vector<Type>{TBase::static_halide_type()};
        }
        return t;
    }

    GeneratorOutput_Buffer(const std::string &name, const std::vector<Type> &t = {}, int d = -1)
        : Super(name, IOKind::Buffer, my_types(t), d) {
    }

    GeneratorOutput_Buffer(size_t array_size, const std::string &name, const std::vector<Type> &t = {}, int d = -1)
        : Super(array_size, name, IOKind::Buffer, my_types(t), d) {
    }

    // C++ type name used for this Output when emitting stub code.
    HALIDE_NO_USER_CODE_INLINE std::string get_c_type() const override {
        if (TBase::has_static_halide_type) {
            return "Halide::Internal::StubOutputBuffer<" +
                   halide_type_to_c_type(TBase::static_halide_type()) +
                   ">";
        } else {
            return "Halide::Internal::StubOutputBuffer<>";
        }
    }

    // Convert this Output to T2 via its conversion operators (e.g.
    // OutputImageParam); the Func overload lives in GeneratorOutputBase.
    template<typename T2, typename std::enable_if<!std::is_same<T2, Func>::value>::type * = nullptr>
    HALIDE_NO_USER_CODE_INLINE T2 as() const {
        return (T2) * this;
    }

public:
    // Allow assignment from a Buffer<> to an Output<Buffer<>>;
    // this allows us to use a statically-compiled buffer inside a Generator
    // to assign to an output.
    // TODO: This used to take the buffer as a const ref. This no longer works as
    // using it in a Pipeline might change the dev field so it is currently
    // not considered const. We should consider how this really ought to work.
    template<typename T2>
    HALIDE_NO_USER_CODE_INLINE GeneratorOutput_Buffer<T> &operator=(Buffer<T2> &buffer) {
        this->check_gio_access();
        this->check_value_writable();

        user_assert(T::can_convert_from(buffer))
            << "Cannot assign to the Output \"" << this->name()
            << "\": the expression is not convertible to the same Buffer type and/or dimensions.\n";

        if (this->types_defined()) {
            user_assert(Type(buffer.type()) == this->type())
                << "Output " << this->name() << " should have type=" << this->type() << " but saw type=" << Type(buffer.type()) << "\n";
        }
        if (this->dims_defined()) {
            user_assert(buffer.dimensions() == this->dims())
                << "Output " << this->name() << " should have dim=" << this->dims() << " but saw dim=" << buffer.dimensions() << "\n";
        }

        internal_assert(this->exprs_.empty() && this->funcs_.size() == 1);
        user_assert(!this->funcs_.at(0).defined());
        // Define the Output's Func as a pure copy of the buffer's contents.
        this->funcs_.at(0)(_) = buffer(_);

        return *this;
    }

    // Allow assignment from a StubOutputBuffer to an Output<Buffer>;
    // this allows us to pipeline the results of a Stub to the results
    // of the enclosing Generator.
    template<typename T2>
    GeneratorOutput_Buffer<T> &operator=(const StubOutputBuffer<T2> &stub_output_buffer) {
        this->check_gio_access();
        assign_from_func(stub_output_buffer.f);
        return *this;
    }

    // Allow assignment from a Func to an Output<Buffer>;
    // this allows us to use helper functions that return a plain Func
    // to simply set the output(s) without needing a wrapper Func.
    GeneratorOutput_Buffer<T> &operator=(const Func &f) {
        this->check_gio_access();
        assign_from_func(f);
        return *this;
    }

    // Convert a non-array Output to the OutputImageParam of its single Func.
    operator OutputImageParam() const {
        this->check_gio_access();
        user_assert(!this->is_array()) << "Cannot convert an Output<Buffer<>[]> to an ImageParam; use an explicit subscript operator: " << this->name();
        internal_assert(this->exprs_.empty() && this->funcs_.size() == 1);
        return this->funcs_.at(0).output_buffer();
    }

    // 'perfect forwarding' won't work with initializer lists,
    // so hand-roll our own forwarding method for set_estimates,
    // rather than using HALIDE_FORWARD_METHOD.
    GeneratorOutput_Buffer<T> &set_estimates(const Region &estimates) {
        this->as<OutputImageParam>().set_estimates(estimates);
        return *this;
    }

    /** Forward methods to the OutputImageParam. */
    // @{
    HALIDE_FORWARD_METHOD(OutputImageParam, dim)
    HALIDE_FORWARD_METHOD_CONST(OutputImageParam, dim)
    HALIDE_FORWARD_METHOD_CONST(OutputImageParam, host_alignment)
    HALIDE_FORWARD_METHOD(OutputImageParam, set_host_alignment)
    HALIDE_FORWARD_METHOD(OutputImageParam, store_in)
    HALIDE_FORWARD_METHOD_CONST(OutputImageParam, dimensions)
    HALIDE_FORWARD_METHOD_CONST(OutputImageParam, left)
    HALIDE_FORWARD_METHOD_CONST(OutputImageParam, right)
    HALIDE_FORWARD_METHOD_CONST(OutputImageParam, top)
    HALIDE_FORWARD_METHOD_CONST(OutputImageParam, bottom)
    HALIDE_FORWARD_METHOD_CONST(OutputImageParam, width)
    HALIDE_FORWARD_METHOD_CONST(OutputImageParam, height)
    HALIDE_FORWARD_METHOD_CONST(OutputImageParam, channels)
    // @}
};
25516
template<typename T>
class GeneratorOutput_Func : public GeneratorOutputImpl<T> {
private:
    using Super = GeneratorOutputImpl<T>;

    // Mutable access to the i'th underlying Func (asserts i is in range).
    HALIDE_NO_USER_CODE_INLINE Func &get_assignable_func_ref(size_t i) {
        internal_assert(this->exprs_.empty() && this->funcs_.size() > i);
        return this->funcs_.at(i);
    }

protected:
    using TBase = typename Super::TBase;

    GeneratorOutput_Func(const std::string &name)
        : Super(name, IOKind::Function, std::vector<Type>{}, -1) {
    }

    GeneratorOutput_Func(const std::string &name, const std::vector<Type> &t, int d = -1)
        : Super(name, IOKind::Function, t, d) {
    }

    GeneratorOutput_Func(size_t array_size, const std::string &name, const std::vector<Type> &t, int d)
        : Super(array_size, name, IOKind::Function, t, d) {
    }

public:
    // Allow Output<Func> = Func
    template<typename T2 = T, typename std::enable_if<!std::is_array<T2>::value>::type * = nullptr>
    GeneratorOutput_Func<T> &operator=(const Func &f) {
        this->check_gio_access();
        this->check_value_writable();

        // Don't bother verifying the Func type, dimensions, etc., here:
        // That's done later, when we produce the pipeline.
        get_assignable_func_ref(0) = f;
        return *this;
    }

    // Allow Output<Func[]> = Func
    template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
    Func &operator[](size_t i) {
        this->check_gio_access();
        this->check_value_writable();
        return get_assignable_func_ref(i);
    }

    // Allow Func = Output<Func[]>
    template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * = nullptr>
    const Func &operator[](size_t i) const {
        this->check_gio_access();
        return Super::operator[](i);
    }

    // Apply the given estimate to every underlying Func (all array elements).
    GeneratorOutput_Func<T> &set_estimate(const Var &var, const Expr &min, const Expr &extent) {
        this->check_gio_access();
        internal_assert(this->exprs_.empty() && !this->funcs_.empty());
        for (Func &f : this->funcs_) {
            f.set_estimate(var, min, extent);
        }
        return *this;
    }

    // Apply the given estimates to every underlying Func (all array elements).
    GeneratorOutput_Func<T> &set_estimates(const Region &estimates) {
        this->check_gio_access();
        internal_assert(this->exprs_.empty() && !this->funcs_.empty());
        for (Func &f : this->funcs_) {
            f.set_estimates(estimates);
        }
        return *this;
    }
};
25588
// Output whose element type is an arithmetic (scalar) C++ type; implemented
// as a zero-dimensional Function of that type.
template<typename T>
class GeneratorOutput_Arithmetic : public GeneratorOutputImpl<T> {
private:
    using Super = GeneratorOutputImpl<T>;

protected:
    using TBase = typename Super::TBase;

    explicit GeneratorOutput_Arithmetic(const std::string &name)
        : Super(name, IOKind::Function, {type_of<TBase>()}, 0) {
    }

    GeneratorOutput_Arithmetic(size_t array_size, const std::string &name)
        : Super(array_size, name, IOKind::Function, {type_of<TBase>()}, 0) {
    }
};
25605
// Select the Output implementation class based on the element type:
// Buffer-like types -> GeneratorOutput_Buffer, Func -> GeneratorOutput_Func,
// arithmetic types -> GeneratorOutput_Arithmetic.
template<typename T, typename TBase = typename std::remove_all_extents<T>::type>
using GeneratorOutputImplBase =
    typename select_type<
        cond<has_static_halide_type_method<TBase>::value, GeneratorOutput_Buffer<T>>,
        cond<std::is_same<TBase, Func>::value, GeneratorOutput_Func<T>>,
        cond<std::is_arithmetic<TBase>::value, GeneratorOutput_Arithmetic<T>>>::type;
25612
25613} // namespace Internal
25614
// Public-facing Output<> type; all ctors simply forward to the
// implementation base selected by GeneratorOutputImplBase.
template<typename T>
class GeneratorOutput : public Internal::GeneratorOutputImplBase<T> {
private:
    using Super = Internal::GeneratorOutputImplBase<T>;

protected:
    using TBase = typename Super::TBase;

public:
    explicit GeneratorOutput(const std::string &name)
        : Super(name) {
    }

    explicit GeneratorOutput(const char *name)
        : GeneratorOutput(std::string(name)) {
    }

    GeneratorOutput(size_t array_size, const std::string &name)
        : Super(array_size, name) {
    }

    // Name + dimensionality; type(s) left unspecified.
    GeneratorOutput(const std::string &name, int d)
        : Super(name, {}, d) {
    }

    // Name + single type + dimensionality.
    GeneratorOutput(const std::string &name, const Type &t, int d)
        : Super(name, {t}, d) {
    }

    // Name + tuple types + dimensionality.
    GeneratorOutput(const std::string &name, const std::vector<Type> &t, int d)
        : Super(name, t, d) {
    }

    GeneratorOutput(size_t array_size, const std::string &name, int d)
        : Super(array_size, name, {}, d) {
    }

    GeneratorOutput(size_t array_size, const std::string &name, const Type &t, int d)
        : Super(array_size, name, {t}, d) {
    }

    GeneratorOutput(size_t array_size, const std::string &name, const std::vector<Type> &t, int d)
        : Super(array_size, name, t, d) {
    }

    // TODO: This used to take the buffer as a const ref. This no longer works as
    // using it in a Pipeline might change the dev field so it is currently
    // not considered const. We should consider how this really ought to work.
    template<typename T2>
    GeneratorOutput<T> &operator=(Buffer<T2> &buffer) {
        Super::operator=(buffer);
        return *this;
    }

    template<typename T2>
    GeneratorOutput<T> &operator=(const Internal::StubOutputBuffer<T2> &stub_output_buffer) {
        Super::operator=(stub_output_buffer);
        return *this;
    }

    GeneratorOutput<T> &operator=(const Func &f) {
        Super::operator=(f);
        return *this;
    }
};
25680
25681namespace Internal {
25682
25683template<typename T>
25684T parse_scalar(const std::string &value) {
25685 std::istringstream iss(value);
25686 T t;
25687 iss >> t;
25688 user_assert(!iss.fail() && iss.get() == EOF) << "Unable to parse: " << value;
25689 return t;
25690}
25691
// Parse a string into a list of Halide Types; defined out-of-line elsewhere.
std::vector<Type> parse_halide_type_list(const std::string &types);

// Which aspect of a GIOBase a synthetic GeneratorParam controls
// (see GeneratorParam_Synthetic below).
enum class SyntheticParamType { Type,
                                Dim,
                                ArraySize };
25697
// This is a type of GeneratorParam used internally to create 'synthetic' params
// (e.g. image.type, image.dim); it is not possible for user code to instantiate it.
template<typename T>
class GeneratorParam_Synthetic : public GeneratorParamImpl<T> {
public:
    void set_from_string(const std::string &new_value_string) override {
        // If error_msg is not empty, this is unsettable:
        // display error_msg as a user error.
        if (!error_msg.empty()) {
            user_error << error_msg;
        }
        // Dispatches via SFINAE: the Type overload for T == Halide::Type,
        // the integral overload for Dim/ArraySize params.
        set_from_string_impl<T>(new_value_string);
    }

    // Synthetic params have no default, no stringification, and no C++ type:
    // these overrides exist only to satisfy the base interface and must
    // never be called.
    std::string get_default_value() const override {
        internal_error;
        return std::string();
    }

    std::string call_to_string(const std::string &v) const override {
        internal_error;
        return std::string();
    }

    std::string get_c_type() const override {
        internal_error;
        return std::string();
    }

    bool is_synthetic_param() const override {
        return true;
    }

private:
    friend class GeneratorParamInfo;

    // Factory used by GeneratorParamInfo; 'defined' means the value was
    // explicitly specified in the C++ source, making this param unsettable.
    static std::unique_ptr<Internal::GeneratorParamBase> make(
        GeneratorBase *generator,
        const std::string &generator_name,
        const std::string &gpname,
        GIOBase &gio,
        SyntheticParamType which,
        bool defined) {
        std::string error_msg = defined ? "Cannot set the GeneratorParam " + gpname + " for " + generator_name + " because the value is explicitly specified in the C++ source." : "";
        return std::unique_ptr<GeneratorParam_Synthetic<T>>(
            new GeneratorParam_Synthetic<T>(gpname, gio, which, error_msg));
    }

    GeneratorParam_Synthetic(const std::string &name, GIOBase &gio, SyntheticParamType which, const std::string &error_msg = "")
        : GeneratorParamImpl<T>(name, T()), gio(gio), which(which), error_msg(error_msg) {
    }

    // T == Halide::Type: set the GIOBase's type list.
    template<typename T2 = T, typename std::enable_if<std::is_same<T2, ::Halide::Type>::value>::type * = nullptr>
    void set_from_string_impl(const std::string &new_value_string) {
        internal_assert(which == SyntheticParamType::Type);
        gio.types_ = parse_halide_type_list(new_value_string);
    }

    // Integral T: set the GIOBase's dimensionality or array size.
    template<typename T2 = T, typename std::enable_if<std::is_integral<T2>::value>::type * = nullptr>
    void set_from_string_impl(const std::string &new_value_string) {
        if (which == SyntheticParamType::Dim) {
            gio.dims_ = parse_scalar<T2>(new_value_string);
        } else if (which == SyntheticParamType::ArraySize) {
            gio.array_size_ = parse_scalar<T2>(new_value_string);
        } else {
            internal_error;
        }
    }

    // The input/output whose type/dim/array-size this param controls.
    GIOBase &gio;
    const SyntheticParamType which;
    // If non-empty, the param is unsettable; shown as the user error.
    const std::string error_msg;
};
25771
25772class GeneratorStub;
25773
25774} // namespace Internal
25775
25776/** GeneratorContext is a base class that is used when using Generators (or Stubs) directly;
25777 * it is used to allow the outer context (typically, either a Generator or "top-level" code)
25778 * to specify certain information to the inner context to ensure that inner and outer
25779 * Generators are compiled in a compatible way.
25780 *
25781 * If you are using this at "top level" (e.g. with the JIT), you can construct a GeneratorContext
25782 * with a Target:
25783 * \code
25784 * auto my_stub = MyStub(
25785 * GeneratorContext(get_target_from_environment()),
25786 * // inputs
25787 * { ... },
25788 * // generator params
25789 * { ... }
25790 * );
25791 * \endcode
25792 *
25793 * Note that all Generators inherit from GeneratorContext, so if you are using a Stub
25794 * from within a Generator, you can just pass 'this' for the GeneratorContext:
25795 * \code
25796 * struct SomeGen : Generator<SomeGen> {
25797 * void generate() {
25798 * ...
25799 * auto my_stub = MyStub(
25800 * this, // GeneratorContext
25801 * // inputs
25802 * { ... },
25803 * // generator params
25804 * { ... }
25805 * );
25806 * ...
25807 * }
25808 * };
25809 * \endcode
25810 */
class GeneratorContext {
public:
    // Map of name -> ExternalCode blocks to append to the final Module.
    using ExternsMap = std::map<std::string, ExternalCode>;

    explicit GeneratorContext(const Target &t,
                              bool auto_schedule = false,
                              const MachineParams &machine_params = MachineParams::generic());
    virtual ~GeneratorContext() = default;

    inline Target get_target() const {
        return target;
    }
    inline bool get_auto_schedule() const {
        return auto_schedule;
    }
    inline MachineParams get_machine_params() const {
        return machine_params;
    }

    /** Generators can register ExternalCode objects onto
     * themselves. The Generator infrastructure will arrange to have
     * this ExternalCode appended to the Module that is finally
     * compiled using the Generator. This allows encapsulating
     * functionality that depends on external libraries or handwritten
     * code for various targets. The name argument should match the
     * name of the ExternalCode block and is used to ensure the same
     * code block is not duplicated in the output. Halide does not do
     * anything other than to compare names for equality. To guarantee
     * uniqueness in public code, we suggest using a Java style
     * inverted domain name followed by organization specific
     * naming. E.g.:
     * com.yoyodyne.overthruster.0719acd19b66df2a9d8d628a8fefba911a0ab2b7
     *
     * See test/generator/external_code_generator.cpp for example use. */
    inline std::shared_ptr<ExternsMap> get_externs_map() const {
        return externs_map;
    }

    // Create an instance of Generator type T using this context.
    template<typename T>
    inline std::unique_ptr<T> create() const {
        return T::create(*this);
    }

    // Create an instance of Generator type T and call apply() on it
    // with the given arguments.
    template<typename T, typename... Args>
    inline std::unique_ptr<T> apply(const Args &...args) const {
        auto t = this->create<T>();
        t->apply(args...);
        return t;
    }

protected:
    GeneratorParam<Target> target;
    GeneratorParam<bool> auto_schedule;
    GeneratorParam<MachineParams> machine_params;
    std::shared_ptr<ExternsMap> externs_map;
    std::shared_ptr<Internal::ValueTracker> value_tracker;

    GeneratorContext()
        : GeneratorContext(Target()) {
    }

    // Copy the outer context's settings into this one.
    virtual void init_from_context(const Halide::GeneratorContext &context);

    inline std::shared_ptr<Internal::ValueTracker> get_value_tracker() const {
        return value_tracker;
    }

public:
    // Contexts are neither copyable nor movable.
    GeneratorContext(const GeneratorContext &) = delete;
    GeneratorContext &operator=(const GeneratorContext &) = delete;
    GeneratorContext(GeneratorContext &&) = delete;
    GeneratorContext &operator=(GeneratorContext &&) = delete;
};
25884
class NamesInterface {
    // Names in this class are only intended for use in derived classes.
protected:
    // Import a consistent list of Halide names that can be used in
    // Halide generators without qualification.
    using Expr = Halide::Expr;
    using EvictionKey = Halide::EvictionKey;
    using ExternFuncArgument = Halide::ExternFuncArgument;
    using Func = Halide::Func;
    using GeneratorContext = Halide::GeneratorContext;
    using ImageParam = Halide::ImageParam;
    using LoopLevel = Halide::LoopLevel;
    using MemoryType = Halide::MemoryType;
    using NameMangling = Halide::NameMangling;
    using Pipeline = Halide::Pipeline;
    using PrefetchBoundStrategy = Halide::PrefetchBoundStrategy;
    using RDom = Halide::RDom;
    using RVar = Halide::RVar;
    using TailStrategy = Halide::TailStrategy;
    using Target = Halide::Target;
    using Tuple = Halide::Tuple;
    using Type = Halide::Type;
    using Var = Halide::Var;
    // Static wrappers for the free-function cast<> overloads.
    template<typename T>
    static Expr cast(Expr e) {
        return Halide::cast<T>(e);
    }
    static inline Expr cast(Halide::Type t, Expr e) {
        return Halide::cast(t, std::move(e));
    }
    template<typename T>
    using GeneratorParam = Halide::GeneratorParam<T>;
    template<typename T = void>
    using Buffer = Halide::Buffer<T>;
    template<typename T>
    using Param = Halide::Param<T>;
    // Static wrappers for the type-constructor free functions.
    static inline Type Bool(int lanes = 1) {
        return Halide::Bool(lanes);
    }
    static inline Type Float(int bits, int lanes = 1) {
        return Halide::Float(bits, lanes);
    }
    static inline Type Int(int bits, int lanes = 1) {
        return Halide::Int(bits, lanes);
    }
    static inline Type UInt(int bits, int lanes = 1) {
        return Halide::UInt(bits, lanes);
    }
};
25934
25935namespace Internal {
25936
// Trait: true iff none of Args... is convertible to Realization.
// Primary template is never selected directly (the specializations below
// cover the empty and non-empty cases).
template<typename... Args>
struct NoRealizations : std::false_type {};

template<>
struct NoRealizations<> : std::true_type {};

// Recursive case: head must not be convertible to Realization, and the
// same must hold for the tail.
template<typename T, typename... Args>
struct NoRealizations<T, Args...> {
    static const bool value = !std::is_convertible<T, Realization>::value && NoRealizations<Args...>::value;
};
25947
class GeneratorStub;

// Factory function used to instantiate a Generator from a GeneratorContext.
// Note that these functions must never return null:
// if they cannot return a valid Generator, they must assert-fail.
using GeneratorFactory = std::function<std::unique_ptr<GeneratorBase>(const GeneratorContext &)>;
25953
// Holds a GeneratorParam value that may be either a plain string or a
// LoopLevel; exactly one of the two fields is meaningful per instance.
struct StringOrLoopLevel {
    std::string string_value;
    LoopLevel loop_level;

    StringOrLoopLevel() = default;
    /*not-explicit*/ StringOrLoopLevel(const char *s)
        : string_value(s) {
    }
    /*not-explicit*/ StringOrLoopLevel(const std::string &s)
        : string_value(s) {
    }
    /*not-explicit*/ StringOrLoopLevel(const LoopLevel &loop_level)
        : loop_level(loop_level) {
    }
};
// GeneratorParam name -> value map used to configure a Generator.
using GeneratorParamsMap = std::map<std::string, StringOrLoopLevel>;
25970
// Collects and owns the reflection info (params, inputs, outputs) for a
// single Generator instance.
class GeneratorParamInfo {
    // names used across all params, inputs, and outputs.
    std::set<std::string> names;

    // Ordered-list of non-null ptrs to GeneratorParam<> fields.
    std::vector<Internal::GeneratorParamBase *> filter_generator_params;

    // Ordered-list of non-null ptrs to Input<> fields.
    std::vector<Internal::GeneratorInputBase *> filter_inputs;

    // Ordered-list of non-null ptrs to Output<> fields; empty if old-style Generator.
    std::vector<Internal::GeneratorOutputBase *> filter_outputs;

    // list of synthetic GP's that we dynamically created; this list only exists to simplify
    // lifetime management, and shouldn't be accessed directly outside of our ctor/dtor,
    // regardless of friend access.
    std::vector<std::unique_ptr<Internal::GeneratorParamBase>> owned_synthetic_params;

    // list of dynamically-added inputs and outputs, here only for lifetime management.
    std::vector<std::unique_ptr<Internal::GIOBase>> owned_extras;

public:
    friend class GeneratorBase;

    GeneratorParamInfo(GeneratorBase *generator, size_t size);

    // Read-only accessors for the collected lists (in declaration order).
    const std::vector<Internal::GeneratorParamBase *> &generator_params() const {
        return filter_generator_params;
    }
    const std::vector<Internal::GeneratorInputBase *> &inputs() const {
        return filter_inputs;
    }
    const std::vector<Internal::GeneratorOutputBase *> &outputs() const {
        return filter_outputs;
    }
};
26007
// Non-templated base class for all Generators: owns the phase state machine,
// the param/input/output bookkeeping (via GeneratorParamInfo), and the
// JIT-mode conveniences (set_inputs/realize). User code derives from
// Generator<T> below, not from this class directly.
class GeneratorBase : public NamesInterface, public GeneratorContext {
public:
    ~GeneratorBase() override;

    // Set GeneratorParam (and LoopLevel) values by name from string/LoopLevel
    // values, e.g. as parsed from a build system's command line.
    void set_generator_param_values(const GeneratorParamsMap &params);

    /** Given a data type, return an estimate of the "natural" vector size
     * for that data type when compiling for the current target. */
    int natural_vector_size(Halide::Type t) const {
        return get_target().natural_vector_size(t);
    }

    /** Given a data type, return an estimate of the "natural" vector size
     * for that data type when compiling for the current target. */
    template<typename data_t>
    int natural_vector_size() const {
        return get_target().natural_vector_size<data_t>();
    }

    // Emit a C++ stub header for this Generator to the given path,
    // suitable for invoking it from another Generator via GeneratorStub.
    void emit_cpp_stub(const std::string &stub_file_path);

    // Call build() and produce a Module for the result.
    // If function_name is empty, generator_name() will be used for the function.
    Module build_module(const std::string &function_name = "",
                        LinkageType linkage_type = LinkageType::ExternalPlusMetadata);

    /**
     * Build a module that is suitable for using for gradient descent calculation in TensorFlow or PyTorch.
     *
     * Essentially:
     * - A new Pipeline is synthesized from the current Generator (according to the rules below)
     * - The new Pipeline is autoscheduled (if autoscheduling is requested, but it would be odd not to do so)
     * - The Pipeline is compiled to a Module and returned
     *
     * The new Pipeline is adjoint to the original; it has:
     * - All the same inputs as the original, in the same order
     * - Followed by one grad-input for each original output
     * - Followed by one output for each unique pairing of original-output + original-input.
     *   (For the common case of just one original-output, this amounts to being one output for each original-input.)
     */
    Module build_gradient_module(const std::string &function_name);

    /**
     * set_inputs is a variadic wrapper around set_inputs_vector, which makes usage much simpler
     * in many cases, as it constructs the relevant entries for the vector for you, which
     * is often a bit unintuitive at present. The arguments are passed in Input<>-declaration-order,
     * and the types must be compatible. Array inputs are passed as std::vector<> of the relevant type.
     *
     * Note: at present, scalar input types must match *exactly*, i.e., for Input<uint8_t>, you
     * must pass an argument that is actually uint8_t; an argument that is int-that-will-fit-in-uint8
     * will assert-fail at Halide compile time.
     */
    template<typename... Args>
    void set_inputs(const Args &...args) {
        // set_inputs_vector() checks this too, but checking it here allows build_inputs() to avoid out-of-range checks.
        GeneratorParamInfo &pi = this->param_info();
        user_assert(sizeof...(args) == pi.inputs().size())
            << "Expected exactly " << pi.inputs().size()
            << " inputs but got " << sizeof...(args) << "\n";
        set_inputs_vector(build_inputs(std::forward_as_tuple<const Args &...>(args...), make_index_sequence<sizeof...(Args)>{}));
    }

    // JIT-realize the Generator's pipeline at the given output sizes.
    // Requires that schedule() has already run (check_scheduled).
    Realization realize(std::vector<int32_t> sizes) {
        this->check_scheduled("realize");
        return get_pipeline().realize(std::move(sizes), get_target());
    }

    // Only enable if none of the args are Realization; otherwise we can incorrectly
    // select this method instead of the Realization-as-outparam variant
    template<typename... Args, typename std::enable_if<NoRealizations<Args...>::value>::type * = nullptr>
    Realization realize(Args &&...args) {
        this->check_scheduled("realize");
        return get_pipeline().realize(std::forward<Args>(args)..., get_target());
    }

    // Realization-as-outparam variant: realize into pre-allocated storage.
    void realize(Realization r) {
        this->check_scheduled("realize");
        get_pipeline().realize(r, get_target());
    }

    // Return the Pipeline that has been built by the generate() method.
    // This method can only be used from a Generator that has a generate()
    // method (vs a build() method), and currently can only be called from
    // the schedule() method. (This may be relaxed in the future to allow
    // calling from generate() as long as all Outputs have been defined.)
    Pipeline get_pipeline();

    // Create Input<Buffer> or Input<Func> with dynamic type
    template<typename T,
             typename std::enable_if<!std::is_arithmetic<T>::value>::type * = nullptr>
    GeneratorInput<T> *add_input(const std::string &name, const Type &t, int dimensions) {
        // Dynamically-added inputs are only legal during configure().
        check_exact_phase(GeneratorBase::ConfigureCalled);
        auto *p = new GeneratorInput<T>(name, t, dimensions);
        p->generator = this;
        // owned_extras takes ownership; filter_inputs is the non-owning ordered list.
        param_info_ptr->owned_extras.push_back(std::unique_ptr<Internal::GIOBase>(p));
        param_info_ptr->filter_inputs.push_back(p);
        return p;
    }

    // Create a Input<Buffer> or Input<Func> with compile-time type
    template<typename T,
             typename std::enable_if<T::has_static_halide_type>::type * = nullptr>
    GeneratorInput<T> *add_input(const std::string &name, int dimensions) {
        check_exact_phase(GeneratorBase::ConfigureCalled);
        auto *p = new GeneratorInput<T>(name, dimensions);
        p->generator = this;
        param_info_ptr->owned_extras.push_back(std::unique_ptr<Internal::GIOBase>(p));
        param_info_ptr->filter_inputs.push_back(p);
        return p;
    }

    // Create Input<scalar>
    template<typename T,
             typename std::enable_if<std::is_arithmetic<T>::value>::type * = nullptr>
    GeneratorInput<T> *add_input(const std::string &name) {
        check_exact_phase(GeneratorBase::ConfigureCalled);
        auto *p = new GeneratorInput<T>(name);
        p->generator = this;
        param_info_ptr->owned_extras.push_back(std::unique_ptr<Internal::GIOBase>(p));
        param_info_ptr->filter_inputs.push_back(p);
        return p;
    }

    // Create Input<Expr> with dynamic type
    template<typename T,
             typename std::enable_if<std::is_same<T, Expr>::value>::type * = nullptr>
    GeneratorInput<T> *add_input(const std::string &name, const Type &type) {
        check_exact_phase(GeneratorBase::ConfigureCalled);
        auto *p = new GeneratorInput<Expr>(name);
        p->generator = this;
        p->set_type(type);
        param_info_ptr->owned_extras.push_back(std::unique_ptr<Internal::GIOBase>(p));
        param_info_ptr->filter_inputs.push_back(p);
        return p;
    }

    // Create Output<Buffer> or Output<Func> with dynamic type
    template<typename T,
             typename std::enable_if<!std::is_arithmetic<T>::value>::type * = nullptr>
    GeneratorOutput<T> *add_output(const std::string &name, const Type &t, int dimensions) {
        check_exact_phase(GeneratorBase::ConfigureCalled);
        auto *p = new GeneratorOutput<T>(name, t, dimensions);
        p->generator = this;
        param_info_ptr->owned_extras.push_back(std::unique_ptr<Internal::GIOBase>(p));
        param_info_ptr->filter_outputs.push_back(p);
        return p;
    }

    // Create a Output<Buffer> or Output<Func> with compile-time type
    template<typename T,
             typename std::enable_if<T::has_static_halide_type>::type * = nullptr>
    GeneratorOutput<T> *add_output(const std::string &name, int dimensions) {
        check_exact_phase(GeneratorBase::ConfigureCalled);
        auto *p = new GeneratorOutput<T>(name, dimensions);
        p->generator = this;
        param_info_ptr->owned_extras.push_back(std::unique_ptr<Internal::GIOBase>(p));
        param_info_ptr->filter_outputs.push_back(p);
        return p;
    }

    // Forward a requirement (a condition that must hold at runtime) to the Pipeline.
    template<typename... Args>
    HALIDE_NO_USER_CODE_INLINE void add_requirement(Expr condition, Args &&...args) {
        get_pipeline().add_requirement(condition, std::forward<Args>(args)...);
    }

    // Enable whole-pipeline tracing on the built Pipeline.
    void trace_pipeline() {
        get_pipeline().trace_pipeline();
    }

protected:
    GeneratorBase(size_t size, const void *introspection_helper);
    void set_generator_names(const std::string &registered_name, const std::string &stub_name);

    void init_from_context(const Halide::GeneratorContext &context) override;

    // Implemented by Generator<T> to dispatch to the derived class's
    // build() or configure()/generate()/schedule() methods.
    virtual Pipeline build_pipeline() = 0;
    virtual void call_configure() = 0;
    virtual void call_generate() = 0;
    virtual void call_schedule() = 0;

    void track_parameter_values(bool include_outputs);

    // Phase-transition hooks bracketing each user-visible Generator method;
    // these advance/verify the `phase` state machine below.
    void pre_build();
    void post_build();
    void pre_configure();
    void post_configure();
    void pre_generate();
    void post_generate();
    void pre_schedule();
    void post_schedule();

    // Convenience aliases so derived Generators can write Input<T>/Output<T>.
    template<typename T>
    using Input = GeneratorInput<T>;

    template<typename T>
    using Output = GeneratorOutput<T>;

    // A Generator's creation and usage must go in a certain phase to ensure correctness;
    // the state machine here is advanced and checked at various points to ensure
    // this is the case.
    enum Phase {
        // Generator has just come into being.
        Created,

        // Generator has had its configure() method called. (For Generators without
        // a configure() method, this phase will be skipped and will advance
        // directly to InputsSet.)
        ConfigureCalled,

        // All Input<>/Param<> fields have been set. (Applicable only in JIT mode;
        // in AOT mode, this can be skipped, going Created->GenerateCalled directly.)
        InputsSet,

        // Generator has had its generate() method called. (For Generators with
        // a build() method instead of generate(), this phase will be skipped
        // and will advance directly to ScheduleCalled.)
        GenerateCalled,

        // Generator has had its schedule() method (if any) called.
        ScheduleCalled,
    } phase{Created};

    // Assert-fail unless `phase` is exactly / at least the expected phase.
    void check_exact_phase(Phase expected_phase) const;
    void check_min_phase(Phase expected_phase) const;
    void advance_phase(Phase new_phase);

    void ensure_configure_has_been_called();

private:
    friend void ::Halide::Internal::generator_test();
    friend class GeneratorParamBase;
    friend class GIOBase;
    friend class GeneratorInputBase;
    friend class GeneratorOutputBase;
    friend class GeneratorParamInfo;
    friend class GeneratorStub;
    friend class StubOutputBufferBase;

    // sizeof() of the most-derived Generator; passed in by Generator<T>'s ctor.
    const size_t size;

    // Lazily-allocated-and-inited struct with info about our various Params.
    // Do not access directly: use the param_info() getter.
    std::unique_ptr<GeneratorParamInfo> param_info_ptr;

    // NOTE(review): mutable — presumably lazily created from const accessors
    // declared elsewhere; confirm in the implementation.
    mutable std::shared_ptr<ExternsMap> externs_map;

    bool inputs_set{false};
    std::string generator_registered_name, generator_stub_name;
    Pipeline pipeline;

    // Return our GeneratorParamInfo.
    GeneratorParamInfo &param_info();

    Internal::GeneratorOutputBase *find_output_by_name(const std::string &name);

    // Assert-fail (mentioning method name `m`) unless phase >= ScheduleCalled.
    void check_scheduled(const char *m) const;

    void build_params(bool force = false);

    // Provide private, unimplemented, wrong-result-type methods here
    // so that Generators don't attempt to call the global methods
    // of the same name by accident: use the get_target() method instead.
    void get_host_target();
    void get_jit_target_from_environment();
    void get_target_from_environment();

    // Return the output with the given name.
    // If the output is singular (a non-array), return a vector of size 1.
    // If no such name exists (or is non-array), assert.
    // This method never returns undefined Funcs.
    std::vector<Func> get_outputs(const std::string &n);

    void set_inputs_vector(const std::vector<std::vector<StubInput>> &inputs);

    // Assertion helpers for the build_input() overloads below.
    static void check_input_is_singular(Internal::GeneratorInputBase *in);
    static void check_input_is_array(Internal::GeneratorInputBase *in);
    static void check_input_kind(Internal::GeneratorInputBase *in, Internal::IOKind kind);

    // Allow Buffer<> if:
    // -- we are assigning it to an Input<Buffer<>> (with compatible type and dimensions),
    // causing the Input<Buffer<>> to become a precompiled buffer in the generated code.
    // -- we are assigning it to an Input<Func>, in which case we just Func-wrap the Buffer<>.
    template<typename T>
    std::vector<StubInput> build_input(size_t i, const Buffer<T> &arg) {
        auto *in = param_info().inputs().at(i);
        check_input_is_singular(in);
        const auto k = in->kind();
        if (k == Internal::IOKind::Buffer) {
            Halide::Buffer<> b = arg;
            StubInputBuffer<> sib(b);
            StubInput si(sib);
            return {si};
        } else if (k == Internal::IOKind::Function) {
            Halide::Func f(arg.name() + "_im");
            f(Halide::_) = arg(Halide::_);
            StubInput si(f);
            return {si};
        } else {
            check_input_kind(in, Internal::IOKind::Buffer);  // just to trigger assertion
            return {};
        }
    }

    // Allow Input<Buffer<>> if:
    // -- we are assigning it to another Input<Buffer<>> (with compatible type and dimensions),
    // allowing us to simply pipe a parameter from an enclosing Generator to the Invoker.
    // -- we are assigning it to an Input<Func>, in which case we just Func-wrap the Input<Buffer<>>.
    template<typename T>
    std::vector<StubInput> build_input(size_t i, const GeneratorInput<Buffer<T>> &arg) {
        auto *in = param_info().inputs().at(i);
        check_input_is_singular(in);
        const auto k = in->kind();
        if (k == Internal::IOKind::Buffer) {
            StubInputBuffer<> sib = arg;
            StubInput si(sib);
            return {si};
        } else if (k == Internal::IOKind::Function) {
            Halide::Func f = arg.funcs().at(0);
            StubInput si(f);
            return {si};
        } else {
            check_input_kind(in, Internal::IOKind::Buffer);  // just to trigger assertion
            return {};
        }
    }

    // Allow Func iff we are assigning it to an Input<Func> (with compatible type and dimensions).
    std::vector<StubInput> build_input(size_t i, const Func &arg) {
        auto *in = param_info().inputs().at(i);
        check_input_kind(in, Internal::IOKind::Function);
        check_input_is_singular(in);
        const Halide::Func &f = arg;
        StubInput si(f);
        return {si};
    }

    // Allow vector<Func> iff we are assigning it to an Input<Func[]> (with compatible type and dimensions).
    std::vector<StubInput> build_input(size_t i, const std::vector<Func> &arg) {
        auto *in = param_info().inputs().at(i);
        check_input_kind(in, Internal::IOKind::Function);
        check_input_is_array(in);
        // My kingdom for a list comprehension...
        std::vector<StubInput> siv;
        siv.reserve(arg.size());
        for (const auto &f : arg) {
            siv.emplace_back(f);
        }
        return siv;
    }

    // Expr must be Input<Scalar>.
    std::vector<StubInput> build_input(size_t i, const Expr &arg) {
        auto *in = param_info().inputs().at(i);
        check_input_kind(in, Internal::IOKind::Scalar);
        check_input_is_singular(in);
        StubInput si(arg);
        return {si};
    }

    // (Array form)
    std::vector<StubInput> build_input(size_t i, const std::vector<Expr> &arg) {
        auto *in = param_info().inputs().at(i);
        check_input_kind(in, Internal::IOKind::Scalar);
        check_input_is_array(in);
        std::vector<StubInput> siv;
        siv.reserve(arg.size());
        for (const auto &value : arg) {
            siv.emplace_back(value);
        }
        return siv;
    }

    // Any other type must be convertible to Expr and must be associated with an Input<Scalar>.
    // Use is_arithmetic since some Expr conversions are explicit.
    template<typename T,
             typename std::enable_if<std::is_arithmetic<T>::value>::type * = nullptr>
    std::vector<StubInput> build_input(size_t i, const T &arg) {
        auto *in = param_info().inputs().at(i);
        check_input_kind(in, Internal::IOKind::Scalar);
        check_input_is_singular(in);
        // We must use an explicit Expr() ctor to preserve the type
        Expr e(arg);
        StubInput si(e);
        return {si};
    }

    // (Array form)
    template<typename T,
             typename std::enable_if<std::is_arithmetic<T>::value>::type * = nullptr>
    std::vector<StubInput> build_input(size_t i, const std::vector<T> &arg) {
        auto *in = param_info().inputs().at(i);
        check_input_kind(in, Internal::IOKind::Scalar);
        check_input_is_array(in);
        std::vector<StubInput> siv;
        siv.reserve(arg.size());
        for (const auto &value : arg) {
            // We must use an explicit Expr() ctor to preserve the type;
            // otherwise, implicit conversions can downgrade (e.g.) float -> int
            Expr e(value);
            siv.emplace_back(e);
        }
        return siv;
    }

    // Expand a tuple of heterogeneous args into one build_input() call per arg,
    // pairing each arg with its positional Input<> via the index pack.
    template<typename... Args, size_t... Indices>
    std::vector<std::vector<StubInput>> build_inputs(const std::tuple<const Args &...> &t, index_sequence<Indices...>) {
        return {build_input(Indices, std::get<Indices>(t))...};
    }

public:
    // Generators are neither copyable nor movable.
    GeneratorBase(const GeneratorBase &) = delete;
    GeneratorBase &operator=(const GeneratorBase &) = delete;
    GeneratorBase(GeneratorBase &&that) = delete;
    GeneratorBase &operator=(GeneratorBase &&that) = delete;
};
26423
// Process-wide, mutex-protected registry mapping generator names to their
// factory functions; populated via RegisterGenerator (see the
// HALIDE_REGISTER_GENERATOR macro below).
class GeneratorRegistry {
public:
    static void register_factory(const std::string &name, GeneratorFactory generator_factory);
    static void unregister_factory(const std::string &name);
    // Return the names of all currently-registered Generators.
    static std::vector<std::string> enumerate();
    // Note that this method will never return null:
    // if it cannot return a valid Generator, it should assert-fail.
    static std::unique_ptr<GeneratorBase> create(const std::string &name,
                                                 const Halide::GeneratorContext &context);

private:
    using GeneratorFactoryMap = std::map<const std::string, GeneratorFactory>;

    GeneratorFactoryMap factories;
    std::mutex mutex;

    // Meyers-style accessor for the process-wide singleton instance.
    static GeneratorRegistry &get_registry();

    GeneratorRegistry() = default;

public:
    // The registry is a singleton: neither copyable nor movable.
    GeneratorRegistry(const GeneratorRegistry &) = delete;
    GeneratorRegistry &operator=(const GeneratorRegistry &) = delete;
    GeneratorRegistry(GeneratorRegistry &&that) = delete;
    GeneratorRegistry &operator=(GeneratorRegistry &&that) = delete;
};
26450
26451} // namespace Internal
26452
26453template<class T>
26454class Generator : public Internal::GeneratorBase {
26455protected:
26456 Generator()
26457 : Internal::GeneratorBase(sizeof(T),
26458 Internal::Introspection::get_introspection_helper<T>()) {
26459 }
26460
26461public:
26462 static std::unique_ptr<T> create(const Halide::GeneratorContext &context) {
26463 // We must have an object of type T (not merely GeneratorBase) to call a protected method,
26464 // because CRTP is a weird beast.
26465 auto g = std::unique_ptr<T>(new T());
26466 g->init_from_context(context);
26467 return g;
26468 }
26469
26470 // This is public but intended only for use by the HALIDE_REGISTER_GENERATOR() macro.
26471 static std::unique_ptr<T> create(const Halide::GeneratorContext &context,
26472 const std::string &registered_name,
26473 const std::string &stub_name) {
26474 auto g = create(context);
26475 g->set_generator_names(registered_name, stub_name);
26476 return g;
26477 }
26478
26479 using Internal::GeneratorBase::apply;
26480 using Internal::GeneratorBase::create;
26481
26482 template<typename... Args>
26483 void apply(const Args &...args) {
26484#ifndef _MSC_VER
26485 // VS2015 apparently has some SFINAE issues, so this can inappropriately
26486 // trigger there. (We'll still fail when generate() is called, just
26487 // with a less-helpful error message.)
26488 static_assert(has_generate_method<T>::value, "apply() is not supported for old-style Generators.");
26489#endif
26490 call_configure();
26491 set_inputs(args...);
26492 call_generate();
26493 call_schedule();
26494 }
26495
26496private:
26497 // std::is_member_function_pointer will fail if there is no member of that name,
26498 // so we use a little SFINAE to detect if there are method-shaped members.
26499 template<typename>
26500 struct type_sink { typedef void type; };
26501
26502 template<typename T2, typename = void>
26503 struct has_configure_method : std::false_type {};
26504
26505 template<typename T2>
26506 struct has_configure_method<T2, typename type_sink<decltype(std::declval<T2>().configure())>::type> : std::true_type {};
26507
26508 template<typename T2, typename = void>
26509 struct has_generate_method : std::false_type {};
26510
26511 template<typename T2>
26512 struct has_generate_method<T2, typename type_sink<decltype(std::declval<T2>().generate())>::type> : std::true_type {};
26513
26514 template<typename T2, typename = void>
26515 struct has_schedule_method : std::false_type {};
26516
26517 template<typename T2>
26518 struct has_schedule_method<T2, typename type_sink<decltype(std::declval<T2>().schedule())>::type> : std::true_type {};
26519
26520 template<typename T2 = T,
26521 typename std::enable_if<!has_generate_method<T2>::value>::type * = nullptr>
26522
26523 // Implementations for build_pipeline_impl(), specialized on whether we
26524 // have build() or generate()/schedule() methods.
26525
26526 // MSVC apparently has some weirdness with the usual sfinae tricks
26527 // for detecting method-shaped things, so we can't actually use
26528 // the helpers above outside of static_assert. Instead we make as
26529 // many overloads as we can exist, and then use C++'s preference
26530 // for treating a 0 as an int rather than a double to choose one
26531 // of them.
26532 Pipeline build_pipeline_impl(double) {
26533 static_assert(!has_configure_method<T2>::value, "The configure() method is ignored if you define a build() method; use generate() instead.");
26534 static_assert(!has_schedule_method<T2>::value, "The schedule() method is ignored if you define a build() method; use generate() instead.");
26535 pre_build();
26536 Pipeline p = ((T *)this)->build();
26537 post_build();
26538 return p;
26539 }
26540
26541 template<typename T2 = T,
26542 typename = decltype(std::declval<T2>().generate())>
26543 Pipeline build_pipeline_impl(int) {
26544 // No: configure() must be called prior to this
26545 // (and in fact, prior to calling set_inputs).
26546 //
26547 // ((T *)this)->call_configure_impl(0, 0);
26548
26549 ((T *)this)->call_generate_impl(0);
26550 ((T *)this)->call_schedule_impl(0, 0);
26551 return get_pipeline();
26552 }
26553
26554 // Implementations for call_configure_impl(), specialized on whether we
26555 // have build() or configure()/generate()/schedule() methods.
26556
26557 void call_configure_impl(double, double) {
26558 pre_configure();
26559 // Called as a side effect for build()-method Generators; quietly do nothing
26560 // (except for pre_configure(), to advance the phase).
26561 post_configure();
26562 }
26563
26564 template<typename T2 = T,
26565 typename = decltype(std::declval<T2>().generate())>
26566 void call_configure_impl(double, int) {
26567 // Generator has a generate() method but no configure() method. This is ok. Just advance the phase.
26568 pre_configure();
26569 static_assert(!has_configure_method<T2>::value, "Did not expect a configure method here.");
26570 post_configure();
26571 }
26572
26573 template<typename T2 = T,
26574 typename = decltype(std::declval<T2>().generate()),
26575 typename = decltype(std::declval<T2>().configure())>
26576 void call_configure_impl(int, int) {
26577 T *t = (T *)this;
26578 static_assert(std::is_void<decltype(t->configure())>::value, "configure() must return void");
26579 pre_configure();
26580 t->configure();
26581 post_configure();
26582 }
26583
26584 // Implementations for call_generate_impl(), specialized on whether we
26585 // have build() or configure()/generate()/schedule() methods.
26586
26587 void call_generate_impl(double) {
26588 user_error << "Unimplemented";
26589 }
26590
26591 template<typename T2 = T,
26592 typename = decltype(std::declval<T2>().generate())>
26593 void call_generate_impl(int) {
26594 T *t = (T *)this;
26595 static_assert(std::is_void<decltype(t->generate())>::value, "generate() must return void");
26596 pre_generate();
26597 t->generate();
26598 post_generate();
26599 }
26600
26601 // Implementations for call_schedule_impl(), specialized on whether we
26602 // have build() or configure()generate()/schedule() methods.
26603
26604 void call_schedule_impl(double, double) {
26605 user_error << "Unimplemented";
26606 }
26607
26608 template<typename T2 = T,
26609 typename = decltype(std::declval<T2>().generate())>
26610 void call_schedule_impl(double, int) {
26611 // Generator has a generate() method but no schedule() method. This is ok. Just advance the phase.
26612 pre_schedule();
26613 post_schedule();
26614 }
26615
26616 template<typename T2 = T,
26617 typename = decltype(std::declval<T2>().generate()),
26618 typename = decltype(std::declval<T2>().schedule())>
26619 void call_schedule_impl(int, int) {
26620 T *t = (T *)this;
26621 static_assert(std::is_void<decltype(t->schedule())>::value, "schedule() must return void");
26622 pre_schedule();
26623 t->schedule();
26624 post_schedule();
26625 }
26626
26627protected:
26628 Pipeline build_pipeline() override {
26629 return this->build_pipeline_impl(0);
26630 }
26631
26632 void call_configure() override {
26633 this->call_configure_impl(0, 0);
26634 }
26635
26636 void call_generate() override {
26637 this->call_generate_impl(0);
26638 }
26639
26640 void call_schedule() override {
26641 this->call_schedule_impl(0, 0);
26642 }
26643
26644private:
26645 friend void ::Halide::Internal::generator_test();
26646 friend void ::Halide::Internal::generator_test();
26647 friend class ::Halide::GeneratorContext;
26648
26649public:
26650 Generator(const Generator &) = delete;
26651 Generator &operator=(const Generator &) = delete;
26652 Generator(Generator &&that) = delete;
26653 Generator &operator=(Generator &&that) = delete;
26654};
26655
26656namespace Internal {
26657
// RAII-style registration helper: constructing one (at namespace scope, via
// the HALIDE_REGISTER_GENERATOR macro) registers the factory under the name.
class RegisterGenerator {
public:
    RegisterGenerator(const char *registered_name, GeneratorFactory generator_factory);
};
26662
// Runtime wrapper around a Generator instance created from a factory, used by
// the C++ stub headers emitted by emit_cpp_stub() to invoke one Generator
// from within another.
class GeneratorStub : public NamesInterface {
public:
    GeneratorStub(const GeneratorContext &context,
                  const GeneratorFactory &generator_factory);

    // Construct and immediately generate() with the given params and inputs.
    GeneratorStub(const GeneratorContext &context,
                  const GeneratorFactory &generator_factory,
                  const GeneratorParamsMap &generator_params,
                  const std::vector<std::vector<Internal::StubInput>> &inputs);
    // Run the wrapped Generator with the given params/inputs; returns the
    // output Funcs, one vector per Output (arrays yield multiple entries).
    std::vector<std::vector<Func>> generate(const GeneratorParamsMap &generator_params,
                                            const std::vector<std::vector<Internal::StubInput>> &inputs);

    // Output(s)
    std::vector<Func> get_outputs(const std::string &n) const {
        return generator->get_outputs(n);
    }

    // As get_outputs(), but wrapping each Func in a T2 (e.g. a
    // StubOutputBuffer) constructed from the Func and the generator.
    template<typename T2>
    std::vector<T2> get_output_buffers(const std::string &n) const {
        auto v = generator->get_outputs(n);
        std::vector<T2> result;
        for (auto &o : v) {
            result.push_back(T2(o, generator));
        }
        return result;
    }

    // Helpers that normalize a single Expr/Func/buffer (or a vector thereof)
    // into the vector<StubInput> form that generate() expects.
    static std::vector<StubInput> to_stub_input_vector(const Expr &e) {
        return {StubInput(e)};
    }

    static std::vector<StubInput> to_stub_input_vector(const Func &f) {
        return {StubInput(f)};
    }

    template<typename T = void>
    static std::vector<StubInput> to_stub_input_vector(const StubInputBuffer<T> &b) {
        return {StubInput(b)};
    }

    template<typename T>
    static std::vector<StubInput> to_stub_input_vector(const std::vector<T> &v) {
        std::vector<StubInput> r;
        // Relies on implicit T -> StubInput conversion during the copy.
        std::copy(v.begin(), v.end(), std::back_inserter(r));
        return r;
    }

    // The names of the wrapped Generator's params, inputs, and outputs,
    // in declaration order.
    struct Names {
        std::vector<std::string> generator_params, inputs, outputs;
    };
    Names get_names() const;

    // The wrapped Generator instance; never null once constructed.
    std::shared_ptr<GeneratorBase> generator;
};
26717
26718} // namespace Internal
26719
26720} // namespace Halide
26721
// Define this namespace at global scope so that anonymous namespaces won't
// defeat our static_assert check; define a dummy type inside so we can
// check for type aliasing injected by anonymous namespace usage
namespace halide_register_generator {
struct halide_global_ns;
};

// Core registration macro: declares a factory function for the Generator in a
// per-name sub-namespace, and instantiates a RegisterGenerator so the factory
// is registered at static-init time. The trailing static_assert verifies the
// macro was expanded at global scope (an anonymous namespace would make the
// two halide_global_ns spellings name different types).
#define _HALIDE_REGISTER_GENERATOR_IMPL(GEN_CLASS_NAME, GEN_REGISTRY_NAME, FULLY_QUALIFIED_STUB_NAME) \
    namespace halide_register_generator { \
    struct halide_global_ns; \
    namespace GEN_REGISTRY_NAME##_ns { \
        std::unique_ptr<Halide::Internal::GeneratorBase> factory(const Halide::GeneratorContext &context); \
        std::unique_ptr<Halide::Internal::GeneratorBase> factory(const Halide::GeneratorContext &context) { \
            return GEN_CLASS_NAME::create(context, #GEN_REGISTRY_NAME, #FULLY_QUALIFIED_STUB_NAME); \
        } \
    } \
    static auto reg_##GEN_REGISTRY_NAME = Halide::Internal::RegisterGenerator(#GEN_REGISTRY_NAME, GEN_REGISTRY_NAME##_ns::factory); \
    } \
    static_assert(std::is_same<::halide_register_generator::halide_global_ns, halide_register_generator::halide_global_ns>::value, \
                  "HALIDE_REGISTER_GENERATOR must be used at global scope");

// Two-argument form: the stub name defaults to the registry name.
#define _HALIDE_REGISTER_GENERATOR2(GEN_CLASS_NAME, GEN_REGISTRY_NAME) \
    _HALIDE_REGISTER_GENERATOR_IMPL(GEN_CLASS_NAME, GEN_REGISTRY_NAME, GEN_REGISTRY_NAME)

// Three-argument form: explicit fully-qualified stub name.
#define _HALIDE_REGISTER_GENERATOR3(GEN_CLASS_NAME, GEN_REGISTRY_NAME, FULLY_QUALIFIED_STUB_NAME) \
    _HALIDE_REGISTER_GENERATOR_IMPL(GEN_CLASS_NAME, GEN_REGISTRY_NAME, FULLY_QUALIFIED_STUB_NAME)

// MSVC has a broken implementation of variadic macros: it expands __VA_ARGS__
// as a single token in argument lists (rather than multiple tokens).
// Jump through some hoops to work around this.
#define __HALIDE_REGISTER_ARGCOUNT_IMPL(_1, _2, _3, COUNT, ...) \
    COUNT

#define _HALIDE_REGISTER_ARGCOUNT_IMPL(ARGS) \
    __HALIDE_REGISTER_ARGCOUNT_IMPL ARGS

// Counts 1..3 arguments by sliding them against the reversed count list.
#define _HALIDE_REGISTER_ARGCOUNT(...) \
    _HALIDE_REGISTER_ARGCOUNT_IMPL((__VA_ARGS__, 3, 2, 1, 0))

#define ___HALIDE_REGISTER_CHOOSER(COUNT) \
    _HALIDE_REGISTER_GENERATOR##COUNT

#define __HALIDE_REGISTER_CHOOSER(COUNT) \
    ___HALIDE_REGISTER_CHOOSER(COUNT)

#define _HALIDE_REGISTER_CHOOSER(COUNT) \
    __HALIDE_REGISTER_CHOOSER(COUNT)

#define _HALIDE_REGISTER_GENERATOR_PASTE(A, B) \
    A B

// Public entry point: dispatches to the 2- or 3-argument form by arg count.
#define HALIDE_REGISTER_GENERATOR(...) \
    _HALIDE_REGISTER_GENERATOR_PASTE(_HALIDE_REGISTER_CHOOSER(_HALIDE_REGISTER_ARGCOUNT(__VA_ARGS__)), (__VA_ARGS__))
26775
26776// HALIDE_REGISTER_GENERATOR_ALIAS() can be used to create an an alias-with-a-particular-set-of-param-values
26777// for a given Generator in the build system. Normally, you wouldn't want to do this;
26778// however, some existing Halide clients have build systems that make it challenging to
26779// specify GeneratorParams inside the build system, and this allows a somewhat simpler
26780// customization route for them. It's highly recommended you don't use this for new code.
26781//
26782// The final argument is really an initializer-list of GeneratorParams, in the form
26783// of an initializer-list for map<string, string>:
26784//
26785// { { "gp-name", "gp-value"} [, { "gp2-name", "gp2-value" }] }
26786//
26787// It is specified as a variadic template argument to allow for the fact that the embedded commas
26788// would otherwise confuse the preprocessor; since (in this case) all we're going to do is
26789// pass it thru as-is, this is fine (and even MSVC's 'broken' __VA_ARGS__ should be OK here).
// Expands to: a forward declaration of the original generator's factory, a
// wrapper factory in GEN_REGISTRY_NAME##_ns that calls it and then applies the
// given GeneratorParam values, and a RegisterGenerator object that registers
// the wrapper under the alias name. The trailing static_assert both enforces
// global-scope usage and consumes the user's terminating semicolon.
#define HALIDE_REGISTER_GENERATOR_ALIAS(GEN_REGISTRY_NAME, ORIGINAL_REGISTRY_NAME, ...) \
    namespace halide_register_generator { \
    struct halide_global_ns; \
    namespace ORIGINAL_REGISTRY_NAME##_ns { \
        std::unique_ptr<Halide::Internal::GeneratorBase> factory(const Halide::GeneratorContext &context); \
    } \
    namespace GEN_REGISTRY_NAME##_ns { \
        std::unique_ptr<Halide::Internal::GeneratorBase> factory(const Halide::GeneratorContext &context); \
        std::unique_ptr<Halide::Internal::GeneratorBase> factory(const Halide::GeneratorContext &context) { \
            auto g = ORIGINAL_REGISTRY_NAME##_ns::factory(context); \
            g->set_generator_param_values(__VA_ARGS__); \
            return g; \
        } \
    } \
    static auto reg_##GEN_REGISTRY_NAME = Halide::Internal::RegisterGenerator(#GEN_REGISTRY_NAME, GEN_REGISTRY_NAME##_ns::factory); \
    } \
    static_assert(std::is_same<::halide_register_generator::halide_global_ns, halide_register_generator::halide_global_ns>::value, \
                  "HALIDE_REGISTER_GENERATOR_ALIAS must be used at global scope");
26808
26809#endif // HALIDE_GENERATOR_H_
26810#ifndef HALIDE_HEXAGON_OFFLOAD_H
26811#define HALIDE_HEXAGON_OFFLOAD_H
26812
26813/** \file
26814 * Defines a lowering pass to pull loops marked with the
26815 * Hexagon device API to a separate module, and call them through the
26816 * Hexagon host runtime module.
26817 */
26818
26819
26820namespace Halide {
26821
26822class Module;
26823struct Target;
26824
26825namespace Internal {
26826
26827/** Pull loops marked with the Hexagon device API to a separate
26828 * module, and call them through the Hexagon host runtime module. */
26829Stmt inject_hexagon_rpc(Stmt s, const Target &host_target, Module &module);
26830
26831Buffer<uint8_t> compile_module_to_hexagon_shared_object(const Module &device_code);
26832
26833} // namespace Internal
26834} // namespace Halide
26835
26836#endif
26837#ifndef HALIDE_IR_HEXAGON_OPTIMIZE_H
26838#define HALIDE_IR_HEXAGON_OPTIMIZE_H
26839
26840/** \file
26841 * Tools for optimizing IR for Hexagon.
26842 */
26843
26844
26845namespace Halide {
26846
26847struct Target;
26848
26849namespace Internal {
26850
26851/** Replace indirect and other loads with simple loads + vlut
26852 * calls. */
26853Stmt optimize_hexagon_shuffles(const Stmt &s, int lut_alignment);
26854
/** Generate vscatter-vgather instructions on Hexagon using VTCM memory.
 * The pass should be run before generating shuffles.
 * Some expressions which generate vscatter-vgathers are:
 * 1. out(x) = lut(foo(x)) -> vgather
 * 2. out(idx(x)) = foo(x) -> vscatter */
26860Stmt scatter_gather_generator(Stmt s);
26861
26862/** Hexagon deinterleaves when performing widening operations, and
26863 * interleaves when performing narrowing operations. This pass
26864 * rewrites widenings/narrowings to be explicit in the IR, and
26865 * attempts to simplify away most of the
26866 * interleaving/deinterleaving. */
26867Stmt optimize_hexagon_instructions(Stmt s, const Target &t);
26868
26869/** Generate deinterleave or interleave operations, operating on
26870 * groups of vectors at a time. */
26871//@{
26872Expr native_deinterleave(const Expr &x);
26873Expr native_interleave(const Expr &x);
26874bool is_native_deinterleave(const Expr &x);
26875bool is_native_interleave(const Expr &x);
26876//@}
26877
26878std::string type_suffix(Type type, bool signed_variants = true);
26879
26880std::string type_suffix(const Expr &a, bool signed_variants = true);
26881
26882std::string type_suffix(const Expr &a, const Expr &b, bool signed_variants = true);
26883
26884std::string type_suffix(const std::vector<Expr> &ops, bool signed_variants = true);
26885
26886} // namespace Internal
26887} // namespace Halide
26888
26889#endif
26890#ifndef HALIDE_INFER_ARGUMENTS_H
26891#define HALIDE_INFER_ARGUMENTS_H
26892
26893#include <vector>
26894
26895
26896/** \file
26897 *
26898 * Interface for a visitor to infer arguments used in a body Stmt.
26899 */
26900
26901namespace Halide {
26902namespace Internal {
26903
26904/** An inferred argument. Inferred args are either Params,
26905 * ImageParams, or Buffers. The first two are handled by the param
26906 * field, and global images are tracked via the buf field. These
26907 * are used directly when jitting, or used for validation when
26908 * compiling with an explicit argument list. */
26909struct InferredArgument {
26910 Argument arg;
26911 Parameter param;
26912 Buffer<> buffer;
26913
26914 bool operator<(const InferredArgument &other) const {
26915 if (arg.is_buffer() && !other.arg.is_buffer()) {
26916 return true;
26917 } else if (other.arg.is_buffer() && !arg.is_buffer()) {
26918 return false;
26919 } else {
26920 return arg.name < other.arg.name;
26921 }
26922 }
26923};
26924
26925class Function;
26926
26927std::vector<InferredArgument> infer_arguments(const Stmt &body, const std::vector<Function> &outputs);
26928
26929} // namespace Internal
26930} // namespace Halide
26931
26932#endif
26933#ifndef HALIDE_HOST_GPU_BUFFER_COPIES_H
26934#define HALIDE_HOST_GPU_BUFFER_COPIES_H
26935
26936/** \file
26937 * Defines the lowering passes that deal with host and device buffer flow.
26938 */
26939
26940#include <string>
26941#include <vector>
26942
26943
26944namespace Halide {
26945
26946struct Target;
26947
26948namespace Internal {
26949
26950/** A helper function to call an extern function, and assert that it
26951 * returns 0. */
26952Stmt call_extern_and_assert(const std::string &name, const std::vector<Expr> &args);
26953
26954/** Inject calls to halide_device_malloc, halide_copy_to_device, and
26955 * halide_copy_to_host as needed. */
26956Stmt inject_host_dev_buffer_copies(Stmt s, const Target &t);
26957
26958} // namespace Internal
26959} // namespace Halide
26960
26961#endif
26962#ifndef HALIDE_INLINE_H
26963#define HALIDE_INLINE_H
26964
26965/** \file
26966 * Methods for replacing calls to functions with their definitions.
26967 */
26968
26969
26970namespace Halide {
26971namespace Internal {
26972
26973class Function;
26974
26975/** Inline a single named function, which must be pure. For a pure function to
26976 * be inlined, it must not have any specializations (i.e. it can only have one
26977 * values definition). */
26978// @{
26979Stmt inline_function(Stmt s, const Function &f);
26980Expr inline_function(Expr e, const Function &f);
26981void inline_function(Function caller, const Function &f);
26982// @}
26983
26984/** Check if the schedule of an inlined function is legal, throwing an error
26985 * if it is not. */
26986void validate_schedule_inlined_function(Function f);
26987
26988} // namespace Internal
26989} // namespace Halide
26990
26991#endif
26992#ifndef HALIDE_INLINE_REDUCTIONS_H
26993#define HALIDE_INLINE_REDUCTIONS_H
26994
26995#include <string>
26996
26997
26998/** \file
26999 * Defines some inline reductions: sum, product, minimum, maximum.
27000 */
27001namespace Halide {
27002
27003class Func;
27004
27005/** An inline reduction. This is suitable for convolution-type
27006 * operations - the reduction will be computed in the innermost loop
27007 * that it is used in. The argument may contain free or implicit
27008 * variables, and must refer to some reduction domain. The free
27009 * variables are still free in the return value, but the reduction
27010 * domain is captured - the result expression does not refer to a
27011 * reduction domain and can be used in a pure function definition.
27012 *
27013 * An example using \ref sum :
27014 *
27015 \code
27016 Func f, g;
27017 Var x;
27018 RDom r(0, 10);
27019 f(x) = x*x;
27020 g(x) = sum(f(x + r));
27021 \endcode
27022 *
27023 * Here g computes some blur of x, but g is still a pure function. The
27024 * sum is being computed by an anonymous reduction function that is
27025 * scheduled innermost within g.
27026 */
27027//@{
27028Expr sum(Expr, const std::string &s = "sum");
27029Expr saturating_sum(Expr, const std::string &s = "saturating_sum");
27030Expr product(Expr, const std::string &s = "product");
27031Expr maximum(Expr, const std::string &s = "maximum");
27032Expr minimum(Expr, const std::string &s = "minimum");
27033//@}
27034
27035/** Variants of the inline reduction in which the RDom is stated
27036 * explicitly. The expression can refer to multiple RDoms, and only
27037 * the inner one is captured by the reduction. This allows you to
27038 * write expressions like:
27039 \code
27040 RDom r1(0, 10), r2(0, 10), r3(0, 10);
27041 Expr e = minimum(r1, product(r2, sum(r3, r1 + r2 + r3)));
27042 \endcode
27043*/
27044// @{
27045Expr sum(const RDom &, Expr, const std::string &s = "sum");
27046Expr saturating_sum(const RDom &r, Expr e, const std::string &s = "saturating_sum");
27047Expr product(const RDom &, Expr, const std::string &s = "product");
27048Expr maximum(const RDom &, Expr, const std::string &s = "maximum");
27049Expr minimum(const RDom &, Expr, const std::string &s = "minimum");
27050// @}
27051
27052/** Returns an Expr or Tuple representing the coordinates of the point
27053 * in the RDom which minimizes or maximizes the expression. The
27054 * expression must refer to some RDom. Also returns the extreme value
27055 * of the expression as the last element of the tuple. */
27056// @{
27057Tuple argmax(Expr, const std::string &s = "argmax");
27058Tuple argmin(Expr, const std::string &s = "argmin");
27059Tuple argmax(const RDom &, Expr, const std::string &s = "argmax");
27060Tuple argmin(const RDom &, Expr, const std::string &s = "argmin");
27061// @}
27062
27063/** Inline reductions create an anonymous helper Func to do the
27064 * work. The variants below instead take a named Func object to use,
27065 * so that it is no longer anonymous and can be scheduled
27066 * (e.g. unrolled across the reduction domain). The Func passed must
27067 * not have any existing definition. */
27068//@{
27069Expr sum(Expr, const Func &);
27070Expr saturating_sum(Expr, const Func &);
27071Expr product(Expr, const Func &);
27072Expr maximum(Expr, const Func &);
27073Expr minimum(Expr, const Func &);
27074Expr sum(const RDom &, Expr, const Func &);
27075Expr saturating_sum(const RDom &r, Expr e, const Func &);
27076Expr product(const RDom &, Expr, const Func &);
27077Expr maximum(const RDom &, Expr, const Func &);
27078Expr minimum(const RDom &, Expr, const Func &);
27079Tuple argmax(Expr, const Func &);
27080Tuple argmin(Expr, const Func &);
27081Tuple argmax(const RDom &, Expr, const Func &);
27082Tuple argmin(const RDom &, Expr, const Func &);
27083//@}
27084
27085} // namespace Halide
27086
27087#endif
27088#ifndef HALIDE_INTEGER_DIVISION_TABLE_H
27089#define HALIDE_INTEGER_DIVISION_TABLE_H
27090
27091#include <cstdint>
27092
27093/** \file
27094 * Tables telling us how to do integer division via fixed-point
27095 * multiplication for various small constants. This file is
27096 * automatically generated by find_inverse.cpp.
27097 */
27098namespace Halide {
27099namespace Internal {
27100namespace IntegerDivision {
27101extern const int64_t table_u8[256][4];
27102extern const int64_t table_s8[256][4];
27103extern const int64_t table_u16[256][4];
27104extern const int64_t table_s16[256][4];
27105extern const int64_t table_u32[256][4];
27106extern const int64_t table_s32[256][4];
27107extern const int64_t table_runtime_u8[256][4];
27108extern const int64_t table_runtime_s8[256][4];
27109extern const int64_t table_runtime_u16[256][4];
27110extern const int64_t table_runtime_s16[256][4];
27111extern const int64_t table_runtime_u32[256][4];
27112extern const int64_t table_runtime_s32[256][4];
27113} // namespace IntegerDivision
27114} // namespace Internal
27115} // namespace Halide
27116
27117#endif
27118#ifndef HALIDE_IR_MATCH_H
27119#define HALIDE_IR_MATCH_H
27120
27121/** \file
27122 * Defines a method to match a fragment of IR against a pattern containing wildcards
27123 */
27124
27125#include <map>
27126#include <random>
27127#include <set>
27128#include <vector>
27129
27130
27131namespace Halide {
27132namespace Internal {
27133
27134/** Does the first expression have the same structure as the second?
27135 * Variables in the first expression with the name * are interpreted
27136 * as wildcards, and their matching equivalent in the second
27137 * expression is placed in the vector give as the third argument.
27138 * Wildcards require the types to match. For the type bits and width,
27139 * a 0 indicates "match anything". So an Int(8, 0) will match 8-bit
27140 * integer vectors of any width (including scalars), and a UInt(0, 0)
27141 * will match any unsigned integer type.
27142 *
27143 * For example:
27144 \code
27145 Expr x = Variable::make(Int(32), "*");
27146 match(x + x, 3 + (2*k), result)
27147 \endcode
27148 * should return true, and set result[0] to 3 and
27149 * result[1] to 2*k.
27150 */
27151bool expr_match(const Expr &pattern, const Expr &expr, std::vector<Expr> &result);
27152
27153/** Does the first expression have the same structure as the second?
27154 * Variables are matched consistently. The first time a variable is
27155 * matched, it assumes the value of the matching part of the second
27156 * expression. Subsequent matches must be equal to the first match.
27157 *
27158 * For example:
27159 \code
27160 Var x("x"), y("y");
27161 match(x*(x + y), a*(a + b), result)
27162 \endcode
27163 * should return true, and set result["x"] = a, and result["y"] = b.
27164 */
27165bool expr_match(const Expr &pattern, const Expr &expr, std::map<std::string, Expr> &result);
27166
27167/** Rewrite the expression x to have `lanes` lanes. This is useful
27168 * for substituting the results of expr_match into a pattern expression. */
27169Expr with_lanes(const Expr &x, int lanes);
27170
27171void expr_match_test();
27172
27173/** An alternative template-metaprogramming approach to expression
27174 * matching. Potentially more efficient. We lift the expression
27175 * pattern into a type, and then use force-inlined functions to
27176 * generate efficient matching and reconstruction code for any
27177 * pattern. Pattern elements are either one of the classes in the
27178 * namespace IRMatcher, or are non-null Exprs (represented as
27179 * BaseExprNode &).
27180 *
27181 * Pattern elements that are fully specified by their pattern can be
27182 * built into an expression using the make method. Some patterns,
27183 * such as a broadcast that matches any number of lanes, don't have
27184 * enough information to recreate an Expr.
27185 */
27186namespace IRMatcher {
27187
27188constexpr int max_wild = 6;
27189
27190static const halide_type_t i64_type = {halide_type_int, 64, 1};
27191
27192/** To save stack space, the matcher objects are largely stateless and
27193 * immutable. This state object is built up during matching and then
27194 * consumed when constructing a replacement Expr.
27195 */
/** To save stack space, the matcher objects are largely stateless and
 * immutable. This state object is built up during matching and then
 * consumed when constructing a replacement Expr.
 */
struct MatcherState {
    // bindings[i] is the subexpression captured by Wild<i>.
    const BaseExprNode *bindings[max_wild];
    // bound_const[i] is the constant captured by WildConst*<i> ...
    halide_scalar_value_t bound_const[max_wild];

    // values of the lanes field with special meaning.
    static constexpr uint16_t signed_integer_overflow = 0x8000;
    static constexpr uint16_t special_values_mask = 0x8000;  // currently only one

    // ... and bound_const_type[i] is that constant's type.
    halide_type_t bound_const_type[max_wild];

    /** Record the expression node captured by Wild<i>. */
    HALIDE_ALWAYS_INLINE
    void set_binding(int i, const BaseExprNode &n) noexcept {
        bindings[i] = &n;
    }

    /** Retrieve the expression node previously captured by Wild<i>. */
    HALIDE_ALWAYS_INLINE
    const BaseExprNode *get_binding(int i) const noexcept {
        return bindings[i];
    }

    /** Record a signed integer constant plus its type at slot i. */
    HALIDE_ALWAYS_INLINE
    void set_bound_const(int i, int64_t s, halide_type_t t) noexcept {
        bound_const[i].u.i64 = s;
        bound_const_type[i] = t;
    }

    /** Record an unsigned integer constant plus its type at slot i. */
    HALIDE_ALWAYS_INLINE
    void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept {
        bound_const[i].u.u64 = u;
        bound_const_type[i] = t;
    }

    /** Record a floating-point constant plus its type at slot i. */
    HALIDE_ALWAYS_INLINE
    void set_bound_const(int i, double f, halide_type_t t) noexcept {
        bound_const[i].u.f64 = f;
        bound_const_type[i] = t;
    }

    /** Record an already-wrapped scalar constant plus its type at slot i. */
    HALIDE_ALWAYS_INLINE
    void set_bound_const(int i, halide_scalar_value_t val, halide_type_t t) noexcept {
        bound_const[i] = val;
        bound_const_type[i] = t;
    }

    /** Retrieve the constant and type previously bound at slot i. */
    HALIDE_ALWAYS_INLINE
    void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept {
        val = bound_const[i];
        type = bound_const_type[i];
    }

    // Deliberately leaves the arrays uninitialized: entries are only read
    // after having been set during a successful match.
    HALIDE_ALWAYS_INLINE
    // NOLINTNEXTLINE(modernize-use-equals-default): Can't use `= default`; clang-tidy complains about noexcept mismatch
    MatcherState() noexcept {
    }
};
27251
// SFINAE helper: enable_if_pattern<T>::type exists only when T declares the
// nested pattern_tag marker type, i.e. when T is an IRMatcher pattern node.
template<typename T,
         typename = typename std::remove_reference<T>::type::pattern_tag>
struct enable_if_pattern {
    struct type {};
};

// Exposes a pattern node's compile-time bitmask of the wildcard slots it
// may bind.
template<typename T>
struct bindings {
    constexpr static uint32_t mask = std::remove_reference<T>::type::binds;
};
27262
27263inline HALIDE_NEVER_INLINE Expr make_const_special_expr(halide_type_t ty) {
27264 const uint16_t flags = ty.lanes & MatcherState::special_values_mask;
27265 ty.lanes &= ~MatcherState::special_values_mask;
27266 if (flags & MatcherState::signed_integer_overflow) {
27267 return make_signed_integer_overflow(ty);
27268 }
27269 // unreachable
27270 return Expr();
27271}
27272
27273HALIDE_ALWAYS_INLINE
27274Expr make_const_expr(halide_scalar_value_t val, halide_type_t ty) {
27275 halide_type_t scalar_type = ty;
27276 if (scalar_type.lanes & MatcherState::special_values_mask) {
27277 return make_const_special_expr(scalar_type);
27278 }
27279
27280 const int lanes = scalar_type.lanes;
27281 scalar_type.lanes = 1;
27282
27283 Expr e;
27284 switch (scalar_type.code) {
27285 case halide_type_int:
27286 e = IntImm::make(scalar_type, val.u.i64);
27287 break;
27288 case halide_type_uint:
27289 e = UIntImm::make(scalar_type, val.u.u64);
27290 break;
27291 case halide_type_float:
27292 case halide_type_bfloat:
27293 e = FloatImm::make(scalar_type, val.u.f64);
27294 break;
27295 default:
27296 // Unreachable
27297 return Expr();
27298 }
27299 if (lanes > 1) {
27300 e = Broadcast::make(e, lanes);
27301 }
27302 return e;
27303}
27304
27305bool equal_helper(const BaseExprNode &a, const BaseExprNode &b) noexcept;
27306
27307// A fast version of expression equality that assumes a well-typed non-null expression tree.
27308HALIDE_ALWAYS_INLINE
27309bool equal(const BaseExprNode &a, const BaseExprNode &b) noexcept {
27310 // Early out
27311 return (&a == &b) ||
27312 ((a.type == b.type) &&
27313 (a.node_type == b.node_type) &&
27314 equal_helper(a, b));
27315}
27316
// A pattern that matches a specific expression
struct SpecificExpr {
    struct pattern_tag {};

    // Binds no wildcard slots.
    constexpr static uint32_t binds = 0;

    // What is the weakest and strongest IR node this could possibly be
    constexpr static IRNodeType min_node_type = IRNodeType::IntImm;
    constexpr static IRNodeType max_node_type = IRNodeType::Shuffle;
    constexpr static bool canonical = true;

    // Held by reference: the Expr this refers to must outlive the pattern.
    const BaseExprNode &expr;

    // Matches iff e is structurally equal to the stored expression.
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        return equal(expr, e);
    }

    // Rebuild the stored expression (the type hint is ignored).
    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const {
        return Expr(&expr);
    }

    // Not a constant; cannot participate in constant folding.
    constexpr static bool foldable = false;
};
27342
27343inline std::ostream &operator<<(std::ostream &s, const SpecificExpr &e) {
27344 s << Expr(&e.expr);
27345 return s;
27346}
27347
27348template<int i>
27349struct WildConstInt {
27350 struct pattern_tag {};
27351
27352 constexpr static uint32_t binds = 1 << i;
27353
27354 constexpr static IRNodeType min_node_type = IRNodeType::IntImm;
27355 constexpr static IRNodeType max_node_type = IRNodeType::IntImm;
27356 constexpr static bool canonical = true;
27357
27358 template<uint32_t bound>
27359 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
27360 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
27361 const BaseExprNode *op = &e;
27362 if (op->node_type == IRNodeType::Broadcast) {
27363 op = ((const Broadcast *)op)->value.get();
27364 }
27365 if (op->node_type != IRNodeType::IntImm) {
27366 return false;
27367 }
27368 int64_t value = ((const IntImm *)op)->value;
27369 if (bound & binds) {
27370 halide_scalar_value_t val;
27371 halide_type_t type;
27372 state.get_bound_const(i, val, type);
27373 return (halide_type_t)e.type == type && value == val.u.i64;
27374 }
27375 state.set_bound_const(i, value, e.type);
27376 return true;
27377 }
27378
27379 template<uint32_t bound>
27380 HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept {
27381 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
27382 if (bound & binds) {
27383 halide_scalar_value_t val;
27384 halide_type_t type;
27385 state.get_bound_const(i, val, type);
27386 return type == i64_type && value == val.u.i64;
27387 }
27388 state.set_bound_const(i, value, i64_type);
27389 return true;
27390 }
27391
27392 HALIDE_ALWAYS_INLINE
27393 Expr make(MatcherState &state, halide_type_t type_hint) const {
27394 halide_scalar_value_t val;
27395 halide_type_t type;
27396 state.get_bound_const(i, val, type);
27397 return make_const_expr(val, type);
27398 }
27399
27400 constexpr static bool foldable = true;
27401
27402 HALIDE_ALWAYS_INLINE
27403 void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const {
27404 state.get_bound_const(i, val, ty);
27405 }
27406};
27407
27408template<int i>
27409std::ostream &operator<<(std::ostream &s, const WildConstInt<i> &c) {
27410 s << "ci" << i;
27411 return s;
27412}
27413
// Matches and binds to a constant unsigned-integer immediate (or a broadcast
// of one) at wildcard slot i.
template<int i>
struct WildConstUInt {
    struct pattern_tag {};

    // Bitmask recording which wildcard slot this pattern binds.
    constexpr static uint32_t binds = 1 << i;

    constexpr static IRNodeType min_node_type = IRNodeType::UIntImm;
    constexpr static IRNodeType max_node_type = IRNodeType::UIntImm;
    constexpr static bool canonical = true;

    // Match a UIntImm (possibly broadcast). If slot i is already bound (per
    // the `bound` mask), the value and type must agree with the earlier
    // binding; otherwise the constant is recorded in `state`.
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
        const BaseExprNode *op = &e;
        // Look through a broadcast to the scalar being replicated.
        if (op->node_type == IRNodeType::Broadcast) {
            op = ((const Broadcast *)op)->value.get();
        }
        if (op->node_type != IRNodeType::UIntImm) {
            return false;
        }
        uint64_t value = ((const UIntImm *)op)->value;
        if (bound & binds) {
            halide_scalar_value_t val;
            halide_type_t type;
            state.get_bound_const(i, val, type);
            return (halide_type_t)e.type == type && value == val.u.u64;
        }
        state.set_bound_const(i, value, e.type);
        return true;
    }

    // Rebuild the bound constant as an Expr (the bound type wins; the
    // type hint is unused).
    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const {
        halide_scalar_value_t val;
        halide_type_t type;
        state.get_bound_const(i, val, type);
        return make_const_expr(val, type);
    }

    constexpr static bool foldable = true;

    // Produce the bound constant for the constant folder.
    HALIDE_ALWAYS_INLINE
    void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
        state.get_bound_const(i, val, ty);
    }
};
27460
27461template<int i>
27462std::ostream &operator<<(std::ostream &s, const WildConstUInt<i> &c) {
27463 s << "cu" << i;
27464 return s;
27465}
27466
// Matches and binds to a constant floating-point immediate (or a broadcast
// of one) at wildcard slot i.
template<int i>
struct WildConstFloat {
    struct pattern_tag {};

    // Bitmask recording which wildcard slot this pattern binds.
    constexpr static uint32_t binds = 1 << i;

    constexpr static IRNodeType min_node_type = IRNodeType::FloatImm;
    constexpr static IRNodeType max_node_type = IRNodeType::FloatImm;
    constexpr static bool canonical = true;

    // Match a FloatImm (possibly broadcast). If slot i is already bound (per
    // the `bound` mask), the value and type must agree with the earlier
    // binding; otherwise the constant is recorded in `state`.
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
        const BaseExprNode *op = &e;
        // Look through a broadcast to the scalar being replicated.
        if (op->node_type == IRNodeType::Broadcast) {
            op = ((const Broadcast *)op)->value.get();
        }
        if (op->node_type != IRNodeType::FloatImm) {
            return false;
        }
        double value = ((const FloatImm *)op)->value;
        if (bound & binds) {
            halide_scalar_value_t val;
            halide_type_t type;
            state.get_bound_const(i, val, type);
            return (halide_type_t)e.type == type && value == val.u.f64;
        }
        state.set_bound_const(i, value, e.type);
        return true;
    }

    // Rebuild the bound constant as an Expr (the bound type wins; the
    // type hint is unused).
    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const {
        halide_scalar_value_t val;
        halide_type_t type;
        state.get_bound_const(i, val, type);
        return make_const_expr(val, type);
    }

    constexpr static bool foldable = true;

    // Produce the bound constant for the constant folder.
    HALIDE_ALWAYS_INLINE
    void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
        state.get_bound_const(i, val, ty);
    }
};
27513
27514template<int i>
27515std::ostream &operator<<(std::ostream &s, const WildConstFloat<i> &c) {
27516 s << "cf" << i;
27517 return s;
27518}
27519
// Matches and binds to any constant Expr: an IntImm, UIntImm, or FloatImm
// immediate, or a broadcast of one. Dispatches to the type-specific wildcard
// with the same slot index, and (like them) supports constant-folding via
// make_folded_const (note foldable = true below).
template<int i>
struct WildConst {
    struct pattern_tag {};

    constexpr static uint32_t binds = 1 << i;

    constexpr static IRNodeType min_node_type = IRNodeType::IntImm;
    constexpr static IRNodeType max_node_type = IRNodeType::FloatImm;
    constexpr static bool canonical = true;

    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
        const BaseExprNode *op = &e;
        // Peek through a broadcast only to classify the underlying scalar;
        // the original node is handed to the specific matcher, which does
        // its own broadcast-peeling.
        if (op->node_type == IRNodeType::Broadcast) {
            op = ((const Broadcast *)op)->value.get();
        }
        switch (op->node_type) {
        case IRNodeType::IntImm:
            return WildConstInt<i>().template match<bound>(e, state);
        case IRNodeType::UIntImm:
            return WildConstUInt<i>().template match<bound>(e, state);
        case IRNodeType::FloatImm:
            return WildConstFloat<i>().template match<bound>(e, state);
        default:
            return false;
        }
    }

    // Bare int64 constants are always signed-integer matches.
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept {
        static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
        return WildConstInt<i>().template match<bound>(e, state);
    }

    // Rebuild the bound constant as an Expr (the bound type wins; the
    // type hint is unused).
    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const {
        halide_scalar_value_t val;
        halide_type_t type;
        state.get_bound_const(i, val, type);
        return make_const_expr(val, type);
    }

    constexpr static bool foldable = true;

    // Produce the bound constant for the constant folder.
    HALIDE_ALWAYS_INLINE
    void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
        state.get_bound_const(i, val, ty);
    }
};
27571
27572template<int i>
27573std::ostream &operator<<(std::ostream &s, const WildConst<i> &c) {
27574 s << "c" << i;
27575 return s;
27576}
27577
// Matches and binds to any Expr
template<int i>
struct Wild {
    struct pattern_tag {};

    // Expr wildcards occupy the upper 16 bits of the binding mask so they
    // never collide with the constant wildcards (which use the low bits).
    constexpr static uint32_t binds = 1 << (i + 16);

    constexpr static IRNodeType min_node_type = IRNodeType::IntImm;
    constexpr static IRNodeType max_node_type = StrongestExprNodeType;
    constexpr static bool canonical = true;

    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        if (bound & binds) {
            // Slot already bound earlier in the pattern: subsequent
            // occurrences must be structurally equal to the first.
            return equal(*state.get_binding(i), e);
        }
        state.set_binding(i, e);
        return true;
    }

    // Reproduce the bound subexpression (the type hint is unused).
    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const {
        return state.get_binding(i);
    }

    constexpr static bool foldable = true;
    // Extract the bound node's value for the constant folder, assuming the
    // bound node is an immediate.
    HALIDE_ALWAYS_INLINE
    void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
        const auto *e = state.get_binding(i);
        ty = e->type;
        switch (e->node_type) {
        case IRNodeType::UIntImm:
            val.u.u64 = ((const UIntImm *)e)->value;
            return;
        case IRNodeType::IntImm:
            val.u.i64 = ((const IntImm *)e)->value;
            return;
        case IRNodeType::FloatImm:
            val.u.f64 = ((const FloatImm *)e)->value;
            return;
        default:
            // The function is noexcept, so silent failure. You
            // shouldn't be calling this if you haven't already
            // checked it's going to be a constant (e.g. with
            // is_const, or because you manually bound a constant Expr
            // to the state).
            val.u.u64 = 0;
        }
    }
};
27628
27629template<int i>
27630std::ostream &operator<<(std::ostream &s, const Wild<i> &op) {
27631 s << "_" << i;
27632 return s;
27633}
27634
// Matches a specific constant or broadcast of that constant. The
// constant must be representable as an int64_t.
struct IntLiteral {
    struct pattern_tag {};
    // The literal value this pattern requires.
    int64_t v;

    // Binds no wildcard slots.
    constexpr static uint32_t binds = 0;

    constexpr static IRNodeType min_node_type = IRNodeType::IntImm;
    constexpr static IRNodeType max_node_type = IRNodeType::FloatImm;
    constexpr static bool canonical = true;

    HALIDE_ALWAYS_INLINE
    explicit IntLiteral(int64_t v)
        : v(v) {
    }

    // Match an immediate of any numeric kind (or a broadcast of one) whose
    // value equals v after conversion to that kind.
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        const BaseExprNode *op = &e;
        if (e.node_type == IRNodeType::Broadcast) {
            op = ((const Broadcast *)op)->value.get();
        }
        switch (op->node_type) {
        case IRNodeType::IntImm:
            return ((const IntImm *)op)->value == (int64_t)v;
        case IRNodeType::UIntImm:
            return ((const UIntImm *)op)->value == (uint64_t)v;
        case IRNodeType::FloatImm:
            return ((const FloatImm *)op)->value == (double)v;
        default:
            return false;
        }
    }

    // Match a bare int64 constant.
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept {
        return v == val;
    }

    // Match another IntLiteral pattern node by value.
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept {
        return v == b.v;
    }

    // Materialize the literal at whatever type the context requires.
    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const {
        return make_const(type_hint, v);
    }

    constexpr static bool foldable = true;

    // Convert v into the scalar union member selected by the type in ty.
    HALIDE_ALWAYS_INLINE
    void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
        // Assume type is already correct
        switch (ty.code) {
        case halide_type_int:
            val.u.i64 = v;
            break;
        case halide_type_uint:
            val.u.u64 = (uint64_t)v;
            break;
        case halide_type_float:
        case halide_type_bfloat:
            val.u.f64 = (double)v;
            break;
        default:
            // Unreachable
            ;
        }
    }
};
27707
// Extract the underlying integer from an IntLiteral pattern node.
HALIDE_ALWAYS_INLINE int64_t unwrap(IntLiteral t) {
    return t.v;
}
27711
27712// Convert a provided pattern, expr, or constant int into the internal
27713// representation we use in the matcher trees.
// Identity overload: anything that already carries a pattern_tag (i.e. is
// an IRMatcher pattern node) passes through unchanged. The defaulted
// template parameter SFINAEs this overload out for non-pattern types.
template<typename T,
         typename = typename std::decay<T>::type::pattern_tag>
HALIDE_ALWAYS_INLINE T pattern_arg(T t) {
    return t;
}
27719HALIDE_ALWAYS_INLINE
27720IntLiteral pattern_arg(int64_t x) {
27721 return IntLiteral{x};
27722}
27723
// Compile-time guard: IRMatcher pattern nodes hold Exprs by reference, so
// passing a temporary Expr into a pattern would leave a dangling reference.
// Trip a static_assert if T is an Expr deduced from an rvalue.
template<typename T>
HALIDE_ALWAYS_INLINE void assert_is_lvalue_if_expr() {
    static_assert(!std::is_same<typename std::decay<T>::type, Expr>::value || std::is_lvalue_reference<T>::value,
                  "Exprs are captured by reference by IRMatcher objects and so must be lvalues");
}
27729
// An Expr becomes a SpecificExpr pattern node that refers (by reference,
// not by refcount) to the underlying IR node.
HALIDE_ALWAYS_INLINE SpecificExpr pattern_arg(const Expr &e) {
    return {*e.get()};
}
27733
27734// Helpers to deref SpecificExprs to const BaseExprNode & rather than
27735// passing them by value anywhere (incurring lots of refcounting)
// Identity overload: any pattern node other than SpecificExpr is passed
// through by value. The two defaulted template parameters SFINAE this
// overload to pattern types and exclude SpecificExpr, which has its own
// overload below that derefs to const BaseExprNode &.
template<typename T,
         // T must be a pattern node
         typename = typename std::decay<T>::type::pattern_tag,
         // But T may not be SpecificExpr
         typename = typename std::enable_if<!std::is_same<typename std::decay<T>::type, SpecificExpr>::value>::type>
HALIDE_ALWAYS_INLINE T unwrap(T t) {
    return t;
}
27744
// Deref a SpecificExpr to the IR node it wraps, avoiding refcount churn.
HALIDE_ALWAYS_INLINE
const BaseExprNode &unwrap(const SpecificExpr &e) {
    return e.expr;
}
27749
27750inline std::ostream &operator<<(std::ostream &s, const IntLiteral &op) {
27751 s << op.v;
27752 return s;
27753}
27754
// Forward declarations of the constant-folding kernels for the binary
// operators. Each Op (Add, Sub, ...) specializes these for the three
// scalar carrier types used by the folder: int64_t, uint64_t, double.
// The halide_type_t is passed mutably so a kernel can set overflow flag
// bits in the lanes field (see the Add/Sub/Mul int64 specializations).
template<typename Op>
int64_t constant_fold_bin_op(halide_type_t &, int64_t, int64_t) noexcept;

template<typename Op>
uint64_t constant_fold_bin_op(halide_type_t &, uint64_t, uint64_t) noexcept;

template<typename Op>
double constant_fold_bin_op(halide_type_t &, double, double) noexcept;
27763
// True for IR node types whose binary operation is commutative. Used by
// BinOp/CmpOp to decide whether a rule is canonical (for commutative ops
// the weaker node type is expected on the right). Written as a single
// expression because the file targets C++11 constexpr.
constexpr bool commutative(IRNodeType t) {
    return (t == IRNodeType::Add ||
            t == IRNodeType::Mul ||
            t == IRNodeType::And ||
            t == IRNodeType::Or ||
            t == IRNodeType::Min ||
            t == IRNodeType::Max ||
            t == IRNodeType::EQ ||
            t == IRNodeType::NE);
}
27774
27775// Matches one of the binary operators
// Pattern node that matches a single binary operator IR node of type Op
// (Add, Sub, Mul, ...), with sub-patterns A and B for its operands.
template<typename Op, typename A, typename B>
struct BinOp {
    struct pattern_tag {};
    A a;
    B b;

    // Union of the wildcard bindings the two sub-patterns may set.
    constexpr static uint32_t binds = bindings<A>::mask | bindings<B>::mask;

    // This pattern matches exactly one IR node type.
    constexpr static IRNodeType min_node_type = Op::_node_type;
    constexpr static IRNodeType max_node_type = Op::_node_type;

    // For commutative bin ops, we expect the weaker IR node type on
    // the right. That is, for the rule to be canonical it must be
    // possible that A is at least as strong as B.
    constexpr static bool canonical =
        A::canonical && B::canonical && (!commutative(Op::_node_type) || (A::max_node_type >= B::min_node_type));

    // Match against an IR node: node type must agree, then both operand
    // sub-patterns must match (b sees any bindings a made).
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        if (e.node_type != Op::_node_type) {
            return false;
        }
        const Op &op = (const Op &)e;
        return (a.template match<bound>(*op.a.get(), state) &&
                b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
    }

    // Match against another BinOp pattern (pattern-vs-pattern matching).
    template<uint32_t bound, typename Op2, typename A2, typename B2>
    HALIDE_ALWAYS_INLINE bool match(const BinOp<Op2, A2, B2> &op, MatcherState &state) const noexcept {
        return (std::is_same<Op, Op2>::value &&
                a.template match<bound>(unwrap(op.a), state) &&
                b.template match<bound | bindings<A>::mask>(unwrap(op.b), state));
    }

    constexpr static bool foldable = A::foldable && B::foldable;

    // Constant-fold this node. The non-literal side is evaluated first so
    // an untyped IntLiteral on the other side can adopt its type. And/Or
    // short-circuit on an absorbing constant without evaluating the other
    // side's fold result.
    HALIDE_ALWAYS_INLINE
    void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
        halide_scalar_value_t val_a, val_b;
        if (std::is_same<A, IntLiteral>::value) {
            b.make_folded_const(val_b, ty, state);
            if ((std::is_same<Op, And>::value && val_b.u.u64 == 0) ||
                (std::is_same<Op, Or>::value && val_b.u.u64 == 1)) {
                // Short circuit
                val = val_b;
                return;
            }
            const uint16_t l = ty.lanes;
            a.make_folded_const(val_a, ty, state);
            ty.lanes |= l;  // Make sure the overflow bits are sticky
        } else {
            a.make_folded_const(val_a, ty, state);
            if ((std::is_same<Op, And>::value && val_a.u.u64 == 0) ||
                (std::is_same<Op, Or>::value && val_a.u.u64 == 1)) {
                // Short circuit
                val = val_a;
                return;
            }
            const uint16_t l = ty.lanes;
            b.make_folded_const(val_b, ty, state);
            ty.lanes |= l;
        }
        // Dispatch to the per-op kernel on the carrier type.
        switch (ty.code) {
        case halide_type_int:
            val.u.i64 = constant_fold_bin_op<Op>(ty, val_a.u.i64, val_b.u.i64);
            break;
        case halide_type_uint:
            val.u.u64 = constant_fold_bin_op<Op>(ty, val_a.u.u64, val_b.u.u64);
            break;
        case halide_type_float:
        case halide_type_bfloat:
            val.u.f64 = constant_fold_bin_op<Op>(ty, val_a.u.f64, val_b.u.f64);
            break;
        default:
            // unreachable
            ;
        }
    }

    // Materialize this pattern as IR. The non-literal side is made first
    // so its type can serve as the hint for the literal side.
    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const noexcept {
        Expr ea, eb;
        if (std::is_same<A, IntLiteral>::value) {
            eb = b.make(state, type_hint);
            ea = a.make(state, eb.type());
        } else {
            ea = a.make(state, type_hint);
            eb = b.make(state, ea.type());
        }
        // We sometimes mix vectors and scalars in the rewrite rules,
        // so insert a broadcast if necessary.
        if (ea.type().is_vector() && !eb.type().is_vector()) {
            eb = Broadcast::make(eb, ea.type().lanes());
        }
        if (eb.type().is_vector() && !ea.type().is_vector()) {
            ea = Broadcast::make(ea, eb.type().lanes());
        }
        return Op::make(std::move(ea), std::move(eb));
    }
};
27876
// Forward declarations of the constant-folding kernels for comparison
// operators. All return uint64_t because a folded comparison is a bool
// (uint with 1 bit) regardless of operand type.
template<typename Op>
uint64_t constant_fold_cmp_op(int64_t, int64_t) noexcept;

template<typename Op>
uint64_t constant_fold_cmp_op(uint64_t, uint64_t) noexcept;

template<typename Op>
uint64_t constant_fold_cmp_op(double, double) noexcept;
27885
27886// Matches one of the comparison operators
// Pattern node that matches a single comparison IR node of type Op
// (LT, LE, EQ, ...), with sub-patterns A and B for its operands.
template<typename Op, typename A, typename B>
struct CmpOp {
    struct pattern_tag {};
    A a;
    B b;

    // Union of the wildcard bindings the two sub-patterns may set.
    constexpr static uint32_t binds = bindings<A>::mask | bindings<B>::mask;

    constexpr static IRNodeType min_node_type = Op::_node_type;
    constexpr static IRNodeType max_node_type = Op::_node_type;
    // Canonical form additionally excludes GE/GT: canonical rules are
    // expected to be phrased with LE/LT.
    constexpr static bool canonical = (A::canonical &&
                                       B::canonical &&
                                       (!commutative(Op::_node_type) || A::max_node_type >= B::min_node_type) &&
                                       (Op::_node_type != IRNodeType::GE) &&
                                       (Op::_node_type != IRNodeType::GT));

    // Match against an IR node: node type must agree, then both operand
    // sub-patterns must match (b sees any bindings a made).
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        if (e.node_type != Op::_node_type) {
            return false;
        }
        const Op &op = (const Op &)e;
        return (a.template match<bound>(*op.a.get(), state) &&
                b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
    }

    // Match against another CmpOp pattern (pattern-vs-pattern matching).
    template<uint32_t bound, typename Op2, typename A2, typename B2>
    HALIDE_ALWAYS_INLINE bool match(const CmpOp<Op2, A2, B2> &op, MatcherState &state) const noexcept {
        return (std::is_same<Op, Op2>::value &&
                a.template match<bound>(unwrap(op.a), state) &&
                b.template match<bound | bindings<A>::mask>(unwrap(op.b), state));
    }

    constexpr static bool foldable = A::foldable && B::foldable;

    // Constant-fold the comparison. The result type is forced to bool
    // (uint, 1 bit) after the operands have been folded.
    HALIDE_ALWAYS_INLINE
    void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
        halide_scalar_value_t val_a, val_b;
        // If one side is an untyped const, evaluate the other side first to get a type hint.
        if (std::is_same<A, IntLiteral>::value) {
            b.make_folded_const(val_b, ty, state);
            const uint16_t l = ty.lanes;
            a.make_folded_const(val_a, ty, state);
            ty.lanes |= l;  // keep overflow flag bits sticky
        } else {
            a.make_folded_const(val_a, ty, state);
            const uint16_t l = ty.lanes;
            b.make_folded_const(val_b, ty, state);
            ty.lanes |= l;
        }
        switch (ty.code) {
        case halide_type_int:
            val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.i64, val_b.u.i64);
            break;
        case halide_type_uint:
            val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.u64, val_b.u.u64);
            break;
        case halide_type_float:
        case halide_type_bfloat:
            val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.f64, val_b.u.f64);
            break;
        default:
            // unreachable
            ;
        }
        ty.code = halide_type_uint;
        ty.bits = 1;
    }

    // Materialize this pattern as IR.
    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const {
        // If one side is an untyped const, evaluate the other side first to get a type hint.
        Expr ea, eb;
        if (std::is_same<A, IntLiteral>::value) {
            eb = b.make(state, {});
            ea = a.make(state, eb.type());
        } else {
            ea = a.make(state, {});
            eb = b.make(state, ea.type());
        }
        // We sometimes mix vectors and scalars in the rewrite rules,
        // so insert a broadcast if necessary.
        if (ea.type().is_vector() && !eb.type().is_vector()) {
            eb = Broadcast::make(eb, ea.type().lanes());
        }
        if (eb.type().is_vector() && !ea.type().is_vector()) {
            ea = Broadcast::make(ea, eb.type().lanes());
        }
        return Op::make(std::move(ea), std::move(eb));
    }
};
27978
27979template<typename A, typename B>
27980std::ostream &operator<<(std::ostream &s, const BinOp<Add, A, B> &op) {
27981 s << "(" << op.a << " + " << op.b << ")";
27982 return s;
27983}
27984
27985template<typename A, typename B>
27986std::ostream &operator<<(std::ostream &s, const BinOp<Sub, A, B> &op) {
27987 s << "(" << op.a << " - " << op.b << ")";
27988 return s;
27989}
27990
27991template<typename A, typename B>
27992std::ostream &operator<<(std::ostream &s, const BinOp<Mul, A, B> &op) {
27993 s << "(" << op.a << " * " << op.b << ")";
27994 return s;
27995}
27996
27997template<typename A, typename B>
27998std::ostream &operator<<(std::ostream &s, const BinOp<Div, A, B> &op) {
27999 s << "(" << op.a << " / " << op.b << ")";
28000 return s;
28001}
28002
28003template<typename A, typename B>
28004std::ostream &operator<<(std::ostream &s, const BinOp<And, A, B> &op) {
28005 s << "(" << op.a << " && " << op.b << ")";
28006 return s;
28007}
28008
28009template<typename A, typename B>
28010std::ostream &operator<<(std::ostream &s, const BinOp<Or, A, B> &op) {
28011 s << "(" << op.a << " || " << op.b << ")";
28012 return s;
28013}
28014
28015template<typename A, typename B>
28016std::ostream &operator<<(std::ostream &s, const BinOp<Min, A, B> &op) {
28017 s << "min(" << op.a << ", " << op.b << ")";
28018 return s;
28019}
28020
28021template<typename A, typename B>
28022std::ostream &operator<<(std::ostream &s, const BinOp<Max, A, B> &op) {
28023 s << "max(" << op.a << ", " << op.b << ")";
28024 return s;
28025}
28026
28027template<typename A, typename B>
28028std::ostream &operator<<(std::ostream &s, const CmpOp<LE, A, B> &op) {
28029 s << "(" << op.a << " <= " << op.b << ")";
28030 return s;
28031}
28032
28033template<typename A, typename B>
28034std::ostream &operator<<(std::ostream &s, const CmpOp<LT, A, B> &op) {
28035 s << "(" << op.a << " < " << op.b << ")";
28036 return s;
28037}
28038
28039template<typename A, typename B>
28040std::ostream &operator<<(std::ostream &s, const CmpOp<GE, A, B> &op) {
28041 s << "(" << op.a << " >= " << op.b << ")";
28042 return s;
28043}
28044
28045template<typename A, typename B>
28046std::ostream &operator<<(std::ostream &s, const CmpOp<GT, A, B> &op) {
28047 s << "(" << op.a << " > " << op.b << ")";
28048 return s;
28049}
28050
28051template<typename A, typename B>
28052std::ostream &operator<<(std::ostream &s, const CmpOp<EQ, A, B> &op) {
28053 s << "(" << op.a << " == " << op.b << ")";
28054 return s;
28055}
28056
28057template<typename A, typename B>
28058std::ostream &operator<<(std::ostream &s, const CmpOp<NE, A, B> &op) {
28059 s << "(" << op.a << " != " << op.b << ")";
28060 return s;
28061}
28062
28063template<typename A, typename B>
28064std::ostream &operator<<(std::ostream &s, const BinOp<Mod, A, B> &op) {
28065 s << "(" << op.a << " % " << op.b << ")";
28066 return s;
28067}
28068
// Build an Add pattern node. Rvalue Exprs are rejected because patterns
// hold Exprs by reference.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp<Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    assert_is_lvalue_if_expr<A>();
    assert_is_lvalue_if_expr<B>();
    return {pattern_arg(a), pattern_arg(b)};
}

// Named form of operator+, with the same lvalue requirement.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b)) {
    assert_is_lvalue_if_expr<A>();
    assert_is_lvalue_if_expr<B>();
    return IRMatcher::operator+(a, b);
}

// Fold signed addition, recording signed overflow (for >= 32-bit types)
// in the type's lanes flag bits, and wrapping to the type's width.
template<>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op<Add>(halide_type_t &t, int64_t a, int64_t b) noexcept {
    t.lanes |= ((t.bits >= 32) && add_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
    int dead_bits = 64 - t.bits;
    // Drop the high bits then sign-extend them back
    return int64_t((uint64_t(a) + uint64_t(b)) << dead_bits) >> dead_bits;
}

// Fold unsigned addition, wrapping to the type's width.
template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op<Add>(halide_type_t &t, uint64_t a, uint64_t b) noexcept {
    uint64_t ones = (uint64_t)(-1);
    return (a + b) & (ones >> (64 - t.bits));
}

// Fold floating-point addition.
template<>
HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Add>(halide_type_t &t, double a, double b) noexcept {
    return a + b;
}
28101
// Build a Sub pattern node. Rvalue Exprs are rejected because patterns
// hold Exprs by reference.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp<Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    assert_is_lvalue_if_expr<A>();
    assert_is_lvalue_if_expr<B>();
    return {pattern_arg(a), pattern_arg(b)};
}

// Named form of operator-, with the same lvalue requirement.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b)) {
    assert_is_lvalue_if_expr<A>();
    assert_is_lvalue_if_expr<B>();
    return IRMatcher::operator-(a, b);
}

// Fold signed subtraction, recording signed overflow (for >= 32-bit
// types) in the type's lanes flag bits, and wrapping to the type's width.
template<>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op<Sub>(halide_type_t &t, int64_t a, int64_t b) noexcept {
    t.lanes |= ((t.bits >= 32) && sub_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
    // Drop the high bits then sign-extend them back
    int dead_bits = 64 - t.bits;
    return int64_t((uint64_t(a) - uint64_t(b)) << dead_bits) >> dead_bits;
}

// Fold unsigned subtraction, wrapping to the type's width.
template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op<Sub>(halide_type_t &t, uint64_t a, uint64_t b) noexcept {
    uint64_t ones = (uint64_t)(-1);
    return (a - b) & (ones >> (64 - t.bits));
}

// Fold floating-point subtraction.
template<>
HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Sub>(halide_type_t &t, double a, double b) noexcept {
    return a - b;
}
28134
// Build a Mul pattern node. Rvalue Exprs are rejected because patterns
// hold Exprs by reference.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp<Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    assert_is_lvalue_if_expr<A>();
    assert_is_lvalue_if_expr<B>();
    return {pattern_arg(a), pattern_arg(b)};
}

// Named form of operator*, with the same lvalue requirement.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b)) {
    assert_is_lvalue_if_expr<A>();
    assert_is_lvalue_if_expr<B>();
    return IRMatcher::operator*(a, b);
}

// Fold signed multiplication, recording signed overflow (for >= 32-bit
// types) in the type's lanes flag bits, and wrapping to the type's width.
template<>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op<Mul>(halide_type_t &t, int64_t a, int64_t b) noexcept {
    t.lanes |= ((t.bits >= 32) && mul_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
    int dead_bits = 64 - t.bits;
    // Drop the high bits then sign-extend them back
    return int64_t((uint64_t(a) * uint64_t(b)) << dead_bits) >> dead_bits;
}

// Fold unsigned multiplication, wrapping to the type's width.
template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op<Mul>(halide_type_t &t, uint64_t a, uint64_t b) noexcept {
    uint64_t ones = (uint64_t)(-1);
    return (a * b) & (ones >> (64 - t.bits));
}

// Fold floating-point multiplication.
template<>
HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Mul>(halide_type_t &t, double a, double b) noexcept {
    return a * b;
}
28167
// Build a Div pattern node. Rvalue Exprs are rejected because patterns
// hold Exprs by reference.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp<Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    assert_is_lvalue_if_expr<A>();
    assert_is_lvalue_if_expr<B>();
    return {pattern_arg(a), pattern_arg(b)};
}
28174
28175template<typename A, typename B>
28176HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b)) {
28177 return IRMatcher::operator/(a, b);
28178}
28179
// Fold division by deferring to div_imp (presumably Halide's Euclidean
// division semantics, including defined div-by-zero — see div_imp).
template<>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op<Div>(halide_type_t &t, int64_t a, int64_t b) noexcept {
    return div_imp(a, b);
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op<Div>(halide_type_t &t, uint64_t a, uint64_t b) noexcept {
    return div_imp(a, b);
}

template<>
HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Div>(halide_type_t &t, double a, double b) noexcept {
    return div_imp(a, b);
}
28194
// Build a Mod pattern node. Rvalue Exprs are rejected because patterns
// hold Exprs by reference.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp<Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    assert_is_lvalue_if_expr<A>();
    assert_is_lvalue_if_expr<B>();
    return {pattern_arg(a), pattern_arg(b)};
}

// Named form of operator%, with the same lvalue requirement.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b)) {
    assert_is_lvalue_if_expr<A>();
    assert_is_lvalue_if_expr<B>();
    return IRMatcher::operator%(a, b);
}

// Fold modulo by deferring to mod_imp (presumably Halide's Euclidean
// modulo semantics, paired with div_imp — see mod_imp).
template<>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op<Mod>(halide_type_t &t, int64_t a, int64_t b) noexcept {
    return mod_imp(a, b);
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op<Mod>(halide_type_t &t, uint64_t a, uint64_t b) noexcept {
    return mod_imp(a, b);
}

template<>
HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Mod>(halide_type_t &t, double a, double b) noexcept {
    return mod_imp(a, b);
}
28223
// Build a Min pattern node. Rvalue Exprs are rejected because patterns
// hold Exprs by reference.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp<Min, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    assert_is_lvalue_if_expr<A>();
    assert_is_lvalue_if_expr<B>();
    return {pattern_arg(a), pattern_arg(b)};
}

// Fold min on each carrier type.
template<>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op<Min>(halide_type_t &t, int64_t a, int64_t b) noexcept {
    return std::min(a, b);
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op<Min>(halide_type_t &t, uint64_t a, uint64_t b) noexcept {
    return std::min(a, b);
}

template<>
HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Min>(halide_type_t &t, double a, double b) noexcept {
    return std::min(a, b);
}
28245
28246template<typename A, typename B>
28247HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp<Max, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
28248 assert_is_lvalue_if_expr<A>();
28249 assert_is_lvalue_if_expr<B>();
28250 return {pattern_arg(std::forward<A>(a)), pattern_arg(std::forward<B>(b))};
28251}
28252
// Fold max on each carrier type.
template<>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op<Max>(halide_type_t &t, int64_t a, int64_t b) noexcept {
    return std::max(a, b);
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op<Max>(halide_type_t &t, uint64_t a, uint64_t b) noexcept {
    return std::max(a, b);
}

template<>
HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Max>(halide_type_t &t, double a, double b) noexcept {
    return std::max(a, b);
}
28267
28268template<typename A, typename B>
28269HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp<LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
28270 return {pattern_arg(a), pattern_arg(b)};
28271}
28272
28273template<typename A, typename B>
28274HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b)) {
28275 return IRMatcher::operator<(a, b);
28276}
28277
// Fold < on each carrier type; result is a bool encoded as uint64_t.
template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<LT>(int64_t a, int64_t b) noexcept {
    return a < b;
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<LT>(uint64_t a, uint64_t b) noexcept {
    return a < b;
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<LT>(double a, double b) noexcept {
    return a < b;
}
28292
28293template<typename A, typename B>
28294HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp<GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
28295 return {pattern_arg(a), pattern_arg(b)};
28296}
28297
28298template<typename A, typename B>
28299HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b)) {
28300 return IRMatcher::operator>(a, b);
28301}
28302
// Fold > on each carrier type; result is a bool encoded as uint64_t.
template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<GT>(int64_t a, int64_t b) noexcept {
    return a > b;
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<GT>(uint64_t a, uint64_t b) noexcept {
    return a > b;
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<GT>(double a, double b) noexcept {
    return a > b;
}
28317
28318template<typename A, typename B>
28319HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp<LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
28320 return {pattern_arg(a), pattern_arg(b)};
28321}
28322
28323template<typename A, typename B>
28324HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b)) {
28325 return IRMatcher::operator<=(a, b);
28326}
28327
// Fold <= on each carrier type; result is a bool encoded as uint64_t.
template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<LE>(int64_t a, int64_t b) noexcept {
    return a <= b;
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<LE>(uint64_t a, uint64_t b) noexcept {
    return a <= b;
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<LE>(double a, double b) noexcept {
    return a <= b;
}
28342
28343template<typename A, typename B>
28344HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp<GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
28345 return {pattern_arg(a), pattern_arg(b)};
28346}
28347
28348template<typename A, typename B>
28349HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b)) {
28350 return IRMatcher::operator>=(a, b);
28351}
28352
// Fold >= on each carrier type; result is a bool encoded as uint64_t.
template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<GE>(int64_t a, int64_t b) noexcept {
    return a >= b;
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<GE>(uint64_t a, uint64_t b) noexcept {
    return a >= b;
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<GE>(double a, double b) noexcept {
    return a >= b;
}
28367
28368template<typename A, typename B>
28369HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp<EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
28370 return {pattern_arg(a), pattern_arg(b)};
28371}
28372
28373template<typename A, typename B>
28374HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b)) {
28375 return IRMatcher::operator==(a, b);
28376}
28377
// Fold == on each carrier type; result is a bool encoded as uint64_t.
template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<EQ>(int64_t a, int64_t b) noexcept {
    return a == b;
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<EQ>(uint64_t a, uint64_t b) noexcept {
    return a == b;
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<EQ>(double a, double b) noexcept {
    return a == b;
}
28392
28393template<typename A, typename B>
28394HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp<NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
28395 return {pattern_arg(a), pattern_arg(b)};
28396}
28397
28398template<typename A, typename B>
28399HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b)) {
28400 return IRMatcher::operator!=(a, b);
28401}
28402
// Fold != on each carrier type; result is a bool encoded as uint64_t.
template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<NE>(int64_t a, int64_t b) noexcept {
    return a != b;
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<NE>(uint64_t a, uint64_t b) noexcept {
    return a != b;
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op<NE>(double a, double b) noexcept {
    return a != b;
}
28417
28418template<typename A, typename B>
28419HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp<Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
28420 return {pattern_arg(a), pattern_arg(b)};
28421}
28422
28423template<typename A, typename B>
28424HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b)) {
28425 return IRMatcher::operator||(a, b);
28426}
28427
// Fold logical or. Booleans are carried as integers, so mask to bit 0.
template<>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op<Or>(halide_type_t &t, int64_t a, int64_t b) noexcept {
    return (a | b) & 1;
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op<Or>(halide_type_t &t, uint64_t a, uint64_t b) noexcept {
    return (a | b) & 1;
}

template<>
HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Or>(halide_type_t &t, double a, double b) noexcept {
    // Unreachable, as it would be a type mismatch.
    return 0;
}
28443
28444template<typename A, typename B>
28445HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp<And, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
28446 return {pattern_arg(a), pattern_arg(b)};
28447}
28448
28449template<typename A, typename B>
28450HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b)) {
28451 return IRMatcher::operator&&(a, b);
28452}
28453
// Fold logical and. Booleans are carried as integers, so mask to bit 0.
template<>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op<And>(halide_type_t &t, int64_t a, int64_t b) noexcept {
    return a & b & 1;
}

template<>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_bin_op<And>(halide_type_t &t, uint64_t a, uint64_t b) noexcept {
    return a & b & 1;
}

template<>
HALIDE_ALWAYS_INLINE double constant_fold_bin_op<And>(halide_type_t &t, double a, double b) noexcept {
    // Unreachable
    return 0;
}
28469
// Fold a parameter pack of masks together with bitwise-or, written as a
// C++11-style constexpr recursion (no fold expressions available).
// Base case: the or of nothing is 0.
constexpr inline uint32_t bitwise_or_reduce() {
    return 0;
}

template<typename... Args>
constexpr uint32_t bitwise_or_reduce(uint32_t head, Args... tail) {
    return head | bitwise_or_reduce(tail...);
}
28478
// Fold a parameter pack of bools together with logical-and, written as a
// C++11-style constexpr recursion. Base case: the and of nothing is true.
constexpr inline bool and_reduce() {
    return true;
}

template<typename... Args>
constexpr bool and_reduce(bool head, Args... tail) {
    return head && and_reduce(tail...);
}
28487
28488// TODO: this can be replaced with std::min() once we require C++14 or later
28489constexpr int const_min(int a, int b) {
28490 return a < b ? a : b;
28491}
28492
// Pattern node that matches a call to a specific Halide intrinsic, with
// one sub-pattern per call argument.
template<typename... Args>
struct Intrin {
    struct pattern_tag {};
    Call::IntrinsicOp intrin;
    std::tuple<Args...> args;

    // Union of the wildcard bindings of all argument sub-patterns.
    static constexpr uint32_t binds = bitwise_or_reduce((bindings<Args>::mask)...);

    constexpr static IRNodeType min_node_type = IRNodeType::Call;
    constexpr static IRNodeType max_node_type = IRNodeType::Call;
    constexpr static bool canonical = and_reduce((Args::canonical)...);

    // Recursive case of argument matching. The unused int/double first
    // parameters make this overload preferred (int) while i is in range;
    // when i == sizeof...(Args) SFINAE removes it and the double overload
    // below terminates the recursion.
    template<int i,
             uint32_t bound,
             typename = typename std::enable_if<(i < sizeof...(Args))>::type>
    HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept {
        using T = decltype(std::get<i>(args));
        return (std::get<i>(args).template match<bound>(*c.args[i].get(), state) &&
                match_args<i + 1, bound | bindings<T>::mask>(0, c, state));
    }

    // Base case: all arguments matched.
    template<int i, uint32_t binds>
    HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept {
        return true;
    }

    // Match against an IR node: must be a Call to the expected intrinsic
    // with all argument sub-patterns matching.
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        if (e.node_type != IRNodeType::Call) {
            return false;
        }
        const Call &c = (const Call &)e;
        return (c.is_intrinsic(intrin) && match_args<0, bound>(0, c, state));
    }

    // Printing helpers, using the same int/double recursion trick.
    template<int i,
             typename = typename std::enable_if<(i < sizeof...(Args))>::type>
    HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const {
        s << std::get<i>(args);
        if (i + 1 < sizeof...(Args)) {
            s << ", ";
        }
        print_args<i + 1>(0, s);
    }

    template<int i>
    HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const {
    }

    HALIDE_ALWAYS_INLINE
    void print_args(std::ostream &s) const {
        print_args<0>(0, s);
    }

    // Materialize the intrinsic call. Arguments are made eagerly and
    // dispatched by arity; const_min clamps the tuple index so the
    // std::get below stays in range for smaller packs.
    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const {
        Expr arg0 = std::get<0>(args).make(state, type_hint);
        if (intrin == Call::likely) {
            return likely(arg0);
        } else if (intrin == Call::likely_if_innermost) {
            return likely_if_innermost(arg0);
        } else if (intrin == Call::abs) {
            return abs(arg0);
        }

        Expr arg1 = std::get<const_min(1, sizeof...(Args) - 1)>(args).make(state, type_hint);
        if (intrin == Call::absd) {
            return absd(arg0, arg1);
        } else if (intrin == Call::widening_add) {
            return widening_add(arg0, arg1);
        } else if (intrin == Call::widening_sub) {
            return widening_sub(arg0, arg1);
        } else if (intrin == Call::widening_mul) {
            return widening_mul(arg0, arg1);
        } else if (intrin == Call::saturating_add) {
            return saturating_add(arg0, arg1);
        } else if (intrin == Call::saturating_sub) {
            return saturating_sub(arg0, arg1);
        } else if (intrin == Call::halving_add) {
            return halving_add(arg0, arg1);
        } else if (intrin == Call::halving_sub) {
            return halving_sub(arg0, arg1);
        } else if (intrin == Call::rounding_halving_add) {
            return rounding_halving_add(arg0, arg1);
        } else if (intrin == Call::rounding_halving_sub) {
            return rounding_halving_sub(arg0, arg1);
        } else if (intrin == Call::shift_left) {
            return arg0 << arg1;
        } else if (intrin == Call::shift_right) {
            return arg0 >> arg1;
        } else if (intrin == Call::rounding_shift_left) {
            return rounding_shift_left(arg0, arg1);
        } else if (intrin == Call::rounding_shift_right) {
            return rounding_shift_right(arg0, arg1);
        }

        Expr arg2 = std::get<const_min(2, sizeof...(Args) - 1)>(args).make(state, type_hint);
        if (intrin == Call::mul_shift_right) {
            return mul_shift_right(arg0, arg1, arg2);
        } else if (intrin == Call::rounding_mul_shift_right) {
            return rounding_mul_shift_right(arg0, arg1, arg2);
        }

        internal_error << "Unhandled intrinsic in IRMatcher: " << intrin;
        return Expr();
    }

    constexpr static bool foldable = false;

    HALIDE_ALWAYS_INLINE
    Intrin(Call::IntrinsicOp intrin, Args... args) noexcept
        : intrin(intrin), args(args...) {
    }
};
28607
28608template<typename... Args>
28609std::ostream &operator<<(std::ostream &s, const Intrin<Args...> &op) {
28610 s << op.intrin << "(";
28611 op.print_args(s);
28612 s << ")";
28613 return s;
28614}
28615
// Build an Intrin pattern for an arbitrary intrinsic op chosen at runtime.
// The typed helpers below cover the common fixed cases.
template<typename... Args>
HALIDE_ALWAYS_INLINE auto intrin(Call::IntrinsicOp intrinsic_op, Args... args) noexcept -> Intrin<decltype(pattern_arg(args))...> {
    return {intrinsic_op, pattern_arg(args)...};
}
28620
// Shorthand constructors for Intrin patterns wrapping Halide's arithmetic
// intrinsics (see Call::IntrinsicOp). Each simply tags its pattern_arg-wrapped
// operands with the corresponding intrinsic; matching and reconstruction are
// handled by the Intrin struct above.
template<typename A, typename B>
auto widening_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    return {Call::widening_add, pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto widening_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    return {Call::widening_sub, pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto widening_mul(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    return {Call::widening_mul, pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto saturating_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    return {Call::saturating_add, pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto saturating_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    return {Call::saturating_sub, pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto halving_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    return {Call::halving_add, pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto halving_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    return {Call::halving_sub, pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    return {Call::rounding_halving_add, pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto rounding_halving_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    return {Call::rounding_halving_sub, pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto shift_left(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    return {Call::shift_left, pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto shift_right(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    return {Call::shift_right, pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    return {Call::rounding_shift_left, pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B>
auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
    return {Call::rounding_shift_right, pattern_arg(a), pattern_arg(b)};
}
template<typename A, typename B, typename C>
auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
    return {Call::mul_shift_right, pattern_arg(a), pattern_arg(b), pattern_arg(c)};
}
template<typename A, typename B, typename C>
auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
    return {Call::rounding_mul_shift_right, pattern_arg(a), pattern_arg(b), pattern_arg(c)};
}
28681
28682template<typename A>
28683struct NotOp {
28684 struct pattern_tag {};
28685 A a;
28686
28687 constexpr static uint32_t binds = bindings<A>::mask;
28688
28689 constexpr static IRNodeType min_node_type = IRNodeType::Not;
28690 constexpr static IRNodeType max_node_type = IRNodeType::Not;
28691 constexpr static bool canonical = A::canonical;
28692
28693 template<uint32_t bound>
28694 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
28695 if (e.node_type != IRNodeType::Not) {
28696 return false;
28697 }
28698 const Not &op = (const Not &)e;
28699 return (a.template match<bound>(*op.a.get(), state));
28700 }
28701
28702 template<uint32_t bound, typename A2>
28703 HALIDE_ALWAYS_INLINE bool match(const NotOp<A2> &op, MatcherState &state) const noexcept {
28704 return a.template match<bound>(unwrap(op.a), state);
28705 }
28706
28707 HALIDE_ALWAYS_INLINE
28708 Expr make(MatcherState &state, halide_type_t type_hint) const {
28709 return Not::make(a.make(state, type_hint));
28710 }
28711
28712 constexpr static bool foldable = A::foldable;
28713
28714 template<typename A1 = A>
28715 HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
28716 a.make_folded_const(val, ty, state);
28717 val.u.u64 = ~val.u.u64;
28718 val.u.u64 &= 1;
28719 }
28720};
28721
// Logical-not of a boolean pattern.
template<typename A>
HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp<decltype(pattern_arg(a))> {
    assert_is_lvalue_if_expr<A>();
    return {pattern_arg(a)};
}

// Named alias for operator! (useful where operator syntax is awkward).
template<typename A>
HALIDE_ALWAYS_INLINE auto not_op(A &&a) -> decltype(IRMatcher::operator!(a)) {
    assert_is_lvalue_if_expr<A>();
    return IRMatcher::operator!(a);
}

template<typename A>
inline std::ostream &operator<<(std::ostream &s, const NotOp<A> &op) {
    s << "!(" << op.a << ")";
    return s;
}
28739
28740template<typename C, typename T, typename F>
28741struct SelectOp {
28742 struct pattern_tag {};
28743 C c;
28744 T t;
28745 F f;
28746
28747 constexpr static uint32_t binds = bindings<C>::mask | bindings<T>::mask | bindings<F>::mask;
28748
28749 constexpr static IRNodeType min_node_type = IRNodeType::Select;
28750 constexpr static IRNodeType max_node_type = IRNodeType::Select;
28751
28752 constexpr static bool canonical = C::canonical && T::canonical && F::canonical;
28753
28754 template<uint32_t bound>
28755 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
28756 if (e.node_type != Select::_node_type) {
28757 return false;
28758 }
28759 const Select &op = (const Select &)e;
28760 return (c.template match<bound>(*op.condition.get(), state) &&
28761 t.template match<bound | bindings<C>::mask>(*op.true_value.get(), state) &&
28762 f.template match<bound | bindings<C>::mask | bindings<T>::mask>(*op.false_value.get(), state));
28763 }
28764 template<uint32_t bound, typename C2, typename T2, typename F2>
28765 HALIDE_ALWAYS_INLINE bool match(const SelectOp<C2, T2, F2> &instance, MatcherState &state) const noexcept {
28766 return (c.template match<bound>(unwrap(instance.c), state) &&
28767 t.template match<bound | bindings<C>::mask>(unwrap(instance.t), state) &&
28768 f.template match<bound | bindings<C>::mask | bindings<T>::mask>(unwrap(instance.f), state));
28769 }
28770
28771 HALIDE_ALWAYS_INLINE
28772 Expr make(MatcherState &state, halide_type_t type_hint) const {
28773 return Select::make(c.make(state, {}), t.make(state, type_hint), f.make(state, type_hint));
28774 }
28775
28776 constexpr static bool foldable = C::foldable && T::foldable && F::foldable;
28777
28778 template<typename C1 = C>
28779 HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
28780 halide_scalar_value_t c_val, t_val, f_val;
28781 halide_type_t c_ty;
28782 c.make_folded_const(c_val, c_ty, state);
28783 if ((c_val.u.u64 & 1) == 1) {
28784 t.make_folded_const(val, ty, state);
28785 } else {
28786 f.make_folded_const(val, ty, state);
28787 }
28788 ty.lanes |= c_ty.lanes & MatcherState::special_values_mask;
28789 }
28790};
28791
template<typename C, typename T, typename F>
std::ostream &operator<<(std::ostream &s, const SelectOp<C, T, F> &op) {
    s << "select(" << op.c << ", " << op.t << ", " << op.f << ")";
    return s;
}

// Build a select pattern from condition/true/false sub-patterns.
template<typename C, typename T, typename F>
HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp<decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))> {
    assert_is_lvalue_if_expr<C>();
    assert_is_lvalue_if_expr<T>();
    assert_is_lvalue_if_expr<F>();
    return {pattern_arg(c), pattern_arg(t), pattern_arg(f)};
}
28805
28806template<typename A, typename B>
28807struct BroadcastOp {
28808 struct pattern_tag {};
28809 A a;
28810 B lanes;
28811
28812 constexpr static uint32_t binds = bindings<A>::mask | bindings<B>::mask;
28813
28814 constexpr static IRNodeType min_node_type = IRNodeType::Broadcast;
28815 constexpr static IRNodeType max_node_type = IRNodeType::Broadcast;
28816
28817 constexpr static bool canonical = A::canonical && B::canonical;
28818
28819 template<uint32_t bound>
28820 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
28821 if (e.node_type == Broadcast::_node_type) {
28822 const Broadcast &op = (const Broadcast &)e;
28823 if (a.template match<bound>(*op.value.get(), state) &&
28824 lanes.template match<bound>(op.lanes, state)) {
28825 return true;
28826 }
28827 }
28828 return false;
28829 }
28830
28831 template<uint32_t bound, typename A2, typename B2>
28832 HALIDE_ALWAYS_INLINE bool match(const BroadcastOp<A2, B2> &op, MatcherState &state) const noexcept {
28833 return (a.template match<bound>(unwrap(op.a), state) &&
28834 lanes.template match<bound | bindings<A>::mask>(unwrap(op.lanes), state));
28835 }
28836
28837 HALIDE_ALWAYS_INLINE
28838 Expr make(MatcherState &state, halide_type_t type_hint) const {
28839 halide_scalar_value_t lanes_val;
28840 halide_type_t ty;
28841 lanes.make_folded_const(lanes_val, ty, state);
28842 int32_t l = (int32_t)lanes_val.u.i64;
28843 type_hint.lanes /= l;
28844 Expr val = a.make(state, type_hint);
28845 if (l == 1) {
28846 return val;
28847 } else {
28848 return Broadcast::make(std::move(val), l);
28849 }
28850 }
28851
28852 constexpr static bool foldable = false;
28853
28854 template<typename A1 = A>
28855 HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
28856 halide_scalar_value_t lanes_val;
28857 halide_type_t lanes_ty;
28858 lanes.make_folded_const(lanes_val, lanes_ty, state);
28859 uint16_t l = (uint16_t)lanes_val.u.i64;
28860 a.make_folded_const(val, ty, state);
28861 ty.lanes = l | (ty.lanes & MatcherState::special_values_mask);
28862 }
28863};
28864
template<typename A, typename B>
inline std::ostream &operator<<(std::ostream &s, const BroadcastOp<A, B> &op) {
    s << "broadcast(" << op.a << ", " << op.lanes << ")";
    return s;
}

// Build a broadcast pattern from a value pattern and a lanes pattern.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes))> {
    assert_is_lvalue_if_expr<A>();
    return {pattern_arg(a), pattern_arg(lanes)};
}
28876
// Pattern that matches (and rebuilds) a Ramp node: ramp(base, stride, lanes).
template<typename A, typename B, typename C>
struct RampOp {
    struct pattern_tag {};
    A a;        // base pattern
    B b;        // stride pattern
    C lanes;    // lane-count pattern

    constexpr static uint32_t binds = bindings<A>::mask | bindings<B>::mask | bindings<C>::mask;

    constexpr static IRNodeType min_node_type = IRNodeType::Ramp;
    constexpr static IRNodeType max_node_type = IRNodeType::Ramp;

    constexpr static bool canonical = A::canonical && B::canonical && C::canonical;

    // Match against an actual IR node; bindings chain left-to-right so
    // repeated wildcards must agree.
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        if (e.node_type != Ramp::_node_type) {
            return false;
        }
        const Ramp &op = (const Ramp &)e;
        if (a.template match<bound>(*op.base.get(), state) &&
            b.template match<bound | bindings<A>::mask>(*op.stride.get(), state) &&
            lanes.template match<bound | bindings<A>::mask | bindings<B>::mask>(op.lanes, state)) {
            return true;
        } else {
            return false;
        }
    }

    // Match against another ramp pattern (pattern-vs-pattern comparison).
    template<uint32_t bound, typename A2, typename B2, typename C2>
    HALIDE_ALWAYS_INLINE bool match(const RampOp<A2, B2, C2> &op, MatcherState &state) const noexcept {
        return (a.template match<bound>(unwrap(op.a), state) &&
                b.template match<bound | bindings<A>::mask>(unwrap(op.b), state) &&
                lanes.template match<bound | bindings<A>::mask | bindings<B>::mask>(unwrap(op.lanes), state));
    }

    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const {
        halide_scalar_value_t lanes_val;
        halide_type_t ty;
        lanes.make_folded_const(lanes_val, ty, state);
        int32_t l = (int32_t)lanes_val.u.i64;
        type_hint.lanes /= l;
        Expr ea, eb;
        // Build the stride first, then type the base from the stride's type
        // so the two operands agree.
        eb = b.make(state, type_hint);
        ea = a.make(state, eb.type());
        return Ramp::make(ea, eb, l);
    }

    constexpr static bool foldable = false;
};
28928
template<typename A, typename B, typename C>
std::ostream &operator<<(std::ostream &s, const RampOp<A, B, C> &op) {
    s << "ramp(" << op.a << ", " << op.b << ", " << op.lanes << ")";
    return s;
}

// Build a ramp pattern from base/stride/lanes sub-patterns.
template<typename A, typename B, typename C>
HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
    assert_is_lvalue_if_expr<A>();
    assert_is_lvalue_if_expr<B>();
    assert_is_lvalue_if_expr<C>();
    return {pattern_arg(a), pattern_arg(b), pattern_arg(c)};
}
28942
// Pattern that matches (and rebuilds) a VectorReduce node with a fixed
// reduction operator (Add/Min/Max/And/Or, chosen at compile time).
template<typename A, typename B, VectorReduce::Operator reduce_op>
struct VectorReduceOp {
    struct pattern_tag {};
    A a;        // pattern for the vector being reduced
    B lanes;    // pattern for the *output* lane count

    constexpr static uint32_t binds = bindings<A>::mask;

    constexpr static IRNodeType min_node_type = IRNodeType::VectorReduce;
    constexpr static IRNodeType max_node_type = IRNodeType::VectorReduce;
    constexpr static bool canonical = A::canonical;

    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        if (e.node_type == VectorReduce::_node_type) {
            const VectorReduce &op = (const VectorReduce &)e;
            // Note: lanes is matched against the result type's lane count,
            // not the input vector's.
            if (op.op == reduce_op &&
                a.template match<bound>(*op.value.get(), state) &&
                lanes.template match<bound | bindings<A>::mask>(op.type.lanes(), state)) {
                return true;
            }
        }
        return false;
    }

    // Match against another vector-reduce pattern; operators must agree.
    template<uint32_t bound, typename A2, typename B2, VectorReduce::Operator reduce_op_2>
    HALIDE_ALWAYS_INLINE bool match(const VectorReduceOp<A2, B2, reduce_op_2> &op, MatcherState &state) const noexcept {
        return (reduce_op == reduce_op_2 &&
                a.template match<bound>(unwrap(op.a), state) &&
                lanes.template match<bound | bindings<A>::mask>(unwrap(op.lanes), state));
    }

    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const {
        halide_scalar_value_t lanes_val;
        halide_type_t ty;
        lanes.make_folded_const(lanes_val, ty, state);
        int l = (int)lanes_val.u.i64;
        return VectorReduce::make(reduce_op, a.make(state, type_hint), l);
    }

    constexpr static bool foldable = false;
};
28986
template<typename A, typename B, VectorReduce::Operator reduce_op>
inline std::ostream &operator<<(std::ostream &s, const VectorReduceOp<A, B, reduce_op> &op) {
    s << "vector_reduce(" << reduce_op << ", " << op.a << ", " << op.lanes << ")";
    return s;
}

// Shorthand constructors for horizontal-reduction patterns, one per
// VectorReduce::Operator. `lanes` is the output lane count pattern.
template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add> {
    assert_is_lvalue_if_expr<A>();
    return {pattern_arg(a), pattern_arg(lanes)};
}

template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min> {
    assert_is_lvalue_if_expr<A>();
    return {pattern_arg(a), pattern_arg(lanes)};
}

template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max> {
    assert_is_lvalue_if_expr<A>();
    return {pattern_arg(a), pattern_arg(lanes)};
}

template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And> {
    assert_is_lvalue_if_expr<A>();
    return {pattern_arg(a), pattern_arg(lanes)};
}

template<typename A, typename B>
HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or> {
    assert_is_lvalue_if_expr<A>();
    return {pattern_arg(a), pattern_arg(lanes)};
}
29022
// Pattern for unary negation. There is no Neg node in Halide IR, so this
// matches (and rebuilds) a Sub whose left operand is the constant zero.
template<typename A>
struct NegateOp {
    struct pattern_tag {};
    A a;

    constexpr static uint32_t binds = bindings<A>::mask;

    constexpr static IRNodeType min_node_type = IRNodeType::Sub;
    constexpr static IRNodeType max_node_type = IRNodeType::Sub;

    constexpr static bool canonical = A::canonical;

    // Matches e iff e is (0 - a).
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        if (e.node_type != Sub::_node_type) {
            return false;
        }
        const Sub &op = (const Sub &)e;
        return (a.template match<bound>(*op.b.get(), state) &&
                is_const_zero(op.a));
    }

    template<uint32_t bound, typename A2>
    HALIDE_ALWAYS_INLINE bool match(NegateOp<A2> &&p, MatcherState &state) const noexcept {
        return a.template match<bound>(unwrap(p.a), state);
    }

    // Rebuild as (0 - a) with a zero of the operand's type.
    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const {
        Expr ea = a.make(state, type_hint);
        Expr z = make_zero(ea.type());
        return Sub::make(std::move(z), std::move(ea));
    }

    constexpr static bool foldable = A::foldable;

    template<typename A1 = A>
    HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
        a.make_folded_const(val, ty, state);
        int dead_bits = 64 - ty.bits;
        switch (ty.code) {
        case halide_type_int:
            // The shift-by-(65 - bits) test detects the most-negative value
            // of the type (value is stored sign-extended in the u64 scratch):
            // nonzero overall but zero in all bits below the sign bit.
            if (ty.bits >= 32 && val.u.u64 && (val.u.u64 << (65 - ty.bits)) == 0) {
                // Trying to negate the most negative signed int for a no-overflow type.
                ty.lanes |= MatcherState::signed_integer_overflow;
            } else {
                // Negate, drop the high bits, and then sign-extend them back
                val.u.i64 = int64_t(uint64_t(-val.u.i64) << dead_bits) >> dead_bits;
            }
            break;
        case halide_type_uint:
            // Unsigned negation wraps modulo 2^bits.
            val.u.u64 = ((-val.u.u64) << dead_bits) >> dead_bits;
            break;
        case halide_type_float:
        case halide_type_bfloat:
            val.u.f64 = -val.u.f64;
            break;
        default:
            // unreachable
            ;
        }
    }
};
29086
template<typename A>
std::ostream &operator<<(std::ostream &s, const NegateOp<A> &op) {
    s << "-" << op.a;
    return s;
}

// Unary minus on a pattern.
template<typename A>
HALIDE_ALWAYS_INLINE auto operator-(A &&a) noexcept -> NegateOp<decltype(pattern_arg(a))> {
    assert_is_lvalue_if_expr<A>();
    return {pattern_arg(a)};
}

// Named alias for unary operator- (useful where operator syntax is awkward).
template<typename A>
HALIDE_ALWAYS_INLINE auto negate(A &&a) -> decltype(IRMatcher::operator-(a)) {
    assert_is_lvalue_if_expr<A>();
    return IRMatcher::operator-(a);
}
29104
// Pattern that matches a Cast node to a specific, runtime-chosen type t.
template<typename A>
struct CastOp {
    struct pattern_tag {};
    Type t;     // the exact destination type this pattern requires
    A a;

    constexpr static uint32_t binds = bindings<A>::mask;

    constexpr static IRNodeType min_node_type = IRNodeType::Cast;
    constexpr static IRNodeType max_node_type = IRNodeType::Cast;
    constexpr static bool canonical = A::canonical;

    // Matches only if the node is a Cast whose result type equals t exactly.
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        if (e.node_type != Cast::_node_type) {
            return false;
        }
        const Cast &op = (const Cast &)e;
        return (e.type == t &&
                a.template match<bound>(*op.value.get(), state));
    }
    template<uint32_t bound, typename A2>
    HALIDE_ALWAYS_INLINE bool match(const CastOp<A2> &op, MatcherState &state) const noexcept {
        return t == op.t && a.template match<bound>(unwrap(op.a), state);
    }

    // Rebuild the cast; the type hint is ignored since t is fixed.
    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const {
        return cast(t, a.make(state, {}));
    }

    constexpr static bool foldable = false;
};
29138
template<typename A>
std::ostream &operator<<(std::ostream &s, const CastOp<A> &op) {
    s << "cast(" << op.t << ", " << op.a << ")";
    return s;
}

// Build a cast pattern with a fixed destination type.
template<typename A>
HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp<decltype(pattern_arg(a))> {
    assert_is_lvalue_if_expr<A>();
    return {t, pattern_arg(a)};
}
29150
// Wrapper that forces its sub-pattern to be constant-folded when the
// right-hand side of a rewrite rule is constructed.
template<typename A>
struct Fold {
    struct pattern_tag {};
    A a;

    constexpr static uint32_t binds = bindings<A>::mask;

    // The result of a fold is always some constant immediate.
    constexpr static IRNodeType min_node_type = IRNodeType::IntImm;
    constexpr static IRNodeType max_node_type = IRNodeType::FloatImm;
    constexpr static bool canonical = true;

    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const noexcept {
        halide_scalar_value_t c;
        halide_type_t ty = type_hint;
        a.make_folded_const(c, ty, state);

        // The result of the fold may have an underspecified type
        // (e.g. because it's from an int literal). Make the type code
        // and bits match the required type, if there is one (we can
        // tell from the bits field).
        if (type_hint.bits) {
            if (((int)ty.code == (int)halide_type_int) &&
                ((int)type_hint.code == (int)halide_type_float)) {
                // Convert the integer payload to a double before relabeling
                // the type as float.
                int64_t x = c.u.i64;
                c.u.f64 = (double)x;
            }
            ty.code = type_hint.code;
            ty.bits = type_hint.bits;
        }

        Expr e = make_const_expr(c, ty);
        return e;
    }

    constexpr static bool foldable = A::foldable;

    // Folding a Fold just folds the sub-pattern.
    template<typename A1 = A>
    HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
        a.make_folded_const(val, ty, state);
    }
};
29193
// Mark a sub-pattern for constant folding when a rule's RHS is built.
template<typename A>
HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold<decltype(pattern_arg(a))> {
    assert_is_lvalue_if_expr<A>();
    return {pattern_arg(a)};
}

template<typename A>
std::ostream &operator<<(std::ostream &s, const Fold<A> &op) {
    s << "fold(" << op.a << ")";
    return s;
}
29205
// Predicate pattern: true iff constant-folding the sub-pattern set any
// special-value flag (e.g. signed integer overflow).
template<typename A>
struct Overflows {
    struct pattern_tag {};
    A a;

    constexpr static uint32_t binds = bindings<A>::mask;

    // This rule is a predicate, so it always evaluates to a boolean,
    // which has IRNodeType UIntImm
    constexpr static IRNodeType min_node_type = IRNodeType::UIntImm;
    constexpr static IRNodeType max_node_type = IRNodeType::UIntImm;
    constexpr static bool canonical = true;

    constexpr static bool foldable = A::foldable;

    template<typename A1 = A>
    HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
        a.make_folded_const(val, ty, state);
        ty.code = halide_type_uint;
        ty.bits = 64;
        // Read the special flags from the folded type's lanes field, then
        // reset lanes to scalar.
        val.u.u64 = (ty.lanes & MatcherState::special_values_mask) != 0;
        ty.lanes = 1;
    }
};

// Build an overflows(...) predicate pattern.
template<typename A>
HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows<decltype(pattern_arg(a))> {
    assert_is_lvalue_if_expr<A>();
    return {pattern_arg(a)};
}

template<typename A>
std::ostream &operator<<(std::ostream &s, const Overflows<A> &op) {
    s << "overflows(" << op.a << ")";
    return s;
}
29242
// Pattern that matches (and produces) the signed_integer_overflow intrinsic.
struct Overflow {
    struct pattern_tag {};

    constexpr static uint32_t binds = 0;

    // Overflow is an intrinsic, represented as a Call node
    constexpr static IRNodeType min_node_type = IRNodeType::Call;
    constexpr static IRNodeType max_node_type = IRNodeType::Call;
    constexpr static bool canonical = true;

    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        if (e.node_type != Call::_node_type) {
            return false;
        }
        const Call &op = (const Call &)e;
        return (op.is_intrinsic(Call::signed_integer_overflow));
    }

    // Produce an overflow marker of the hinted type.
    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const {
        type_hint.lanes |= MatcherState::signed_integer_overflow;
        return make_const_special_expr(type_hint);
    }

    constexpr static bool foldable = true;

    // Folding overflow yields zero with the overflow flag set in lanes.
    HALIDE_ALWAYS_INLINE
    void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
        val.u.u64 = 0;
        ty.lanes |= MatcherState::signed_integer_overflow;
    }
};

inline std::ostream &operator<<(std::ostream &s, const Overflow &op) {
    s << "overflow()";
    return s;
}
29281
// Predicate pattern: true iff the bound sub-expression is a constant
// (optionally a specific constant value v, when check_v is set).
template<typename A>
struct IsConst {
    struct pattern_tag {};

    constexpr static uint32_t binds = bindings<A>::mask;

    // This rule is a boolean-valued predicate. Bools have type UIntImm.
    constexpr static IRNodeType min_node_type = IRNodeType::UIntImm;
    constexpr static IRNodeType max_node_type = IRNodeType::UIntImm;
    constexpr static bool canonical = true;

    A a;
    bool check_v;   // when true, also require the constant to equal v
    int64_t v;

    constexpr static bool foldable = true;

    template<typename A1 = A>
    HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
        // Materialize the bound expression and query it with Halide's
        // is_const helpers.
        Expr e = a.make(state, {});
        ty.code = halide_type_uint;
        ty.bits = 64;
        ty.lanes = 1;
        if (check_v) {
            val.u.u64 = ::Halide::Internal::is_const(e, v) ? 1 : 0;
        } else {
            val.u.u64 = ::Halide::Internal::is_const(e) ? 1 : 0;
        }
    }
};
29312
// Predicate: the sub-pattern is any constant.
template<typename A>
HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst<decltype(pattern_arg(a))> {
    assert_is_lvalue_if_expr<A>();
    return {pattern_arg(a), false, 0};
}

// Predicate: the sub-pattern is the specific constant `value`.
template<typename A>
HALIDE_ALWAYS_INLINE auto is_const(A &&a, int64_t value) noexcept -> IsConst<decltype(pattern_arg(a))> {
    assert_is_lvalue_if_expr<A>();
    return {pattern_arg(a), true, value};
}
29324
29325template<typename A>
29326std::ostream &operator<<(std::ostream &s, const IsConst<A> &op) {
29327 if (op.check_v) {
29328 s << "is_const(" << op.a << ")";
29329 } else {
29330 s << "is_const(" << op.a << ", " << op.v << ")";
29331 }
29332 return s;
29333}
29334
// Predicate pattern: true iff an external simplifying mutator (the "prover")
// reduces the bound boolean expression to the constant true.
template<typename A, typename Prover>
struct CanProve {
    struct pattern_tag {};
    A a;
    Prover *prover;  // An existing simplifying mutator

    constexpr static uint32_t binds = bindings<A>::mask;

    // This rule is a boolean-valued predicate. Bools have type UIntImm.
    constexpr static IRNodeType min_node_type = IRNodeType::UIntImm;
    constexpr static IRNodeType max_node_type = IRNodeType::UIntImm;
    constexpr static bool canonical = true;

    constexpr static bool foldable = true;

    // Includes a raw call to an inlined make method, so don't inline.
    HALIDE_NEVER_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const {
        Expr condition = a.make(state, {});
        condition = prover->mutate(condition, nullptr);
        // True only if the prover simplified the condition all the way to 1.
        val.u.u64 = is_const_one(condition);
        ty.code = halide_type_uint;
        ty.bits = 1;
        ty.lanes = condition.type().lanes();
    }
};

// Build a can_prove(...) predicate bound to an existing prover/simplifier.
template<typename A, typename Prover>
HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve<decltype(pattern_arg(a)), Prover> {
    assert_is_lvalue_if_expr<A>();
    return {pattern_arg(a), p};
}

template<typename A, typename Prover>
std::ostream &operator<<(std::ostream &s, const CanProve<A, Prover> &op) {
    s << "can_prove(" << op.a << ")";
    return s;
}
29372
// Predicate pattern: true iff the bound sub-expression has a float type.
template<typename A>
struct IsFloat {
    struct pattern_tag {};
    A a;

    constexpr static uint32_t binds = bindings<A>::mask;

    // This rule is a boolean-valued predicate. Bools have type UIntImm.
    constexpr static IRNodeType min_node_type = IRNodeType::UIntImm;
    constexpr static IRNodeType max_node_type = IRNodeType::UIntImm;
    constexpr static bool canonical = true;

    constexpr static bool foldable = true;

    HALIDE_ALWAYS_INLINE
    void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const {
        // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
        Type t = a.make(state, {}).type();
        val.u.u64 = t.is_float();
        ty.code = halide_type_uint;
        ty.bits = 1;
        ty.lanes = t.lanes();
    }
};

// Build an is_float(...) predicate pattern.
template<typename A>
HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat<decltype(pattern_arg(a))> {
    assert_is_lvalue_if_expr<A>();
    return {pattern_arg(a)};
}

template<typename A>
std::ostream &operator<<(std::ostream &s, const IsFloat<A> &op) {
    s << "is_float(" << op.a << ")";
    return s;
}
29409
// Predicate pattern: true iff the bound sub-expression has a signed integer
// type (optionally with an exact bit width; bits == 0 means any width).
template<typename A>
struct IsInt {
    struct pattern_tag {};
    A a;
    int bits;   // required bit width, or 0 for "any"

    constexpr static uint32_t binds = bindings<A>::mask;

    // This rule is a boolean-valued predicate. Bools have type UIntImm.
    constexpr static IRNodeType min_node_type = IRNodeType::UIntImm;
    constexpr static IRNodeType max_node_type = IRNodeType::UIntImm;
    constexpr static bool canonical = true;

    constexpr static bool foldable = true;

    HALIDE_ALWAYS_INLINE
    void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const {
        // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
        Type t = a.make(state, {}).type();
        val.u.u64 = t.is_int() && (bits == 0 || t.bits() == bits);
        ty.code = halide_type_uint;
        ty.bits = 1;
        ty.lanes = t.lanes();
    }
};

// Build an is_int(...) predicate pattern; bits == 0 accepts any width.
template<typename A>
HALIDE_ALWAYS_INLINE auto is_int(A &&a, int bits = 0) noexcept -> IsInt<decltype(pattern_arg(a))> {
    assert_is_lvalue_if_expr<A>();
    return {pattern_arg(a), bits};
}

template<typename A>
std::ostream &operator<<(std::ostream &s, const IsInt<A> &op) {
    s << "is_int(" << op.a;
    if (op.bits > 0) {
        s << ", " << op.bits;
    }
    s << ")";
    return s;
}
29451
// Predicate pattern: matches when the bound sub-pattern 'a' has an unsigned
// integer type, optionally of a specific bit-width (bits == 0 means any).
template<typename A>
struct IsUInt {
    struct pattern_tag {};
    A a;
    // Required bit-width of the matched type; 0 accepts any width.
    int bits;

    constexpr static uint32_t binds = bindings<A>::mask;

    // This rule is a boolean-valued predicate. Bools have type UIntImm.
    constexpr static IRNodeType min_node_type = IRNodeType::UIntImm;
    constexpr static IRNodeType max_node_type = IRNodeType::UIntImm;
    constexpr static bool canonical = true;

    constexpr static bool foldable = true;

    HALIDE_ALWAYS_INLINE
    void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const {
        // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
        Type t = a.make(state, {}).type();
        // 1 iff the type is unsigned int and the width matches (or is unconstrained).
        val.u.u64 = t.is_uint() && (bits == 0 || t.bits() == bits);
        ty.code = halide_type_uint;
        ty.bits = 1;
        ty.lanes = t.lanes();
    }
};
29477
// Construct a predicate that matches when the sub-pattern's type is an
// unsigned integer. bits == 0 (the default) accepts any bit-width.
template<typename A>
HALIDE_ALWAYS_INLINE auto is_uint(A &&a, int bits = 0) noexcept -> IsUInt<decltype(pattern_arg(a))> {
    assert_is_lvalue_if_expr<A>();
    return {pattern_arg(a), bits};
}
29483
29484template<typename A>
29485std::ostream &operator<<(std::ostream &s, const IsUInt<A> &op) {
29486 s << "is_uint(" << op.a;
29487 if (op.bits > 0) {
29488 s << ", " << op.bits;
29489 }
29490 s << ")";
29491 return s;
29492}
29493
// Predicate pattern: matches when the bound sub-pattern 'a' has a scalar
// (single-lane) type.
template<typename A>
struct IsScalar {
    struct pattern_tag {};
    A a;

    constexpr static uint32_t binds = bindings<A>::mask;

    // This rule is a boolean-valued predicate. Bools have type UIntImm.
    constexpr static IRNodeType min_node_type = IRNodeType::UIntImm;
    constexpr static IRNodeType max_node_type = IRNodeType::UIntImm;
    constexpr static bool canonical = true;

    constexpr static bool foldable = true;

    HALIDE_ALWAYS_INLINE
    void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const {
        // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
        Type t = a.make(state, {}).type();
        val.u.u64 = t.is_scalar();
        ty.code = halide_type_uint;
        ty.bits = 1;
        ty.lanes = t.lanes();
    }
};
29518
// Construct a predicate that matches when the sub-pattern's type is scalar
// (lanes == 1).
template<typename A>
HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar<decltype(pattern_arg(a))> {
    assert_is_lvalue_if_expr<A>();
    return {pattern_arg(a)};
}
29524
// Predicate pattern: matches when the bound sub-pattern 'a' folds to the
// maximum representable value of its integer type. Always false for
// non-integer types.
template<typename A>
struct IsMaxValue {
    struct pattern_tag {};
    A a;

    constexpr static uint32_t binds = bindings<A>::mask;

    // This rule is a boolean-valued predicate. Bools have type UIntImm.
    constexpr static IRNodeType min_node_type = IRNodeType::UIntImm;
    constexpr static IRNodeType max_node_type = IRNodeType::UIntImm;
    constexpr static bool canonical = true;

    constexpr static bool foldable = true;

    HALIDE_ALWAYS_INLINE
    void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const {
        // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
        // Fold 'a' first; its value and type arrive in val/ty.
        a.make_folded_const(val, ty, state);
        // Bit pattern of the max value: 2^bits - 1 for uint; for signed int
        // one extra bit is shifted off to clear the sign bit, giving
        // 2^(bits-1) - 1.
        const uint64_t max_bits = (uint64_t)(-1) >> (64 - ty.bits + (ty.code == halide_type_int));
        if (ty.code == halide_type_uint || ty.code == halide_type_int) {
            // Compare raw bit patterns.
            val.u.u64 = (val.u.u64 == max_bits);
        } else {
            // Floats (and anything else) never match.
            val.u.u64 = 0;
        }
        ty.code = halide_type_uint;
        ty.bits = 1;
    }
};
29553
// Construct a predicate that matches when the sub-pattern folds to the
// maximum value of its integer type.
template<typename A>
HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue<decltype(pattern_arg(a))> {
    assert_is_lvalue_if_expr<A>();
    return {pattern_arg(a)};
}
29559
// Predicate pattern: matches when the bound sub-pattern 'a' folds to the
// minimum representable value of its integer type (zero for uints, the
// most-negative value for signed ints). Always false for non-integer types.
template<typename A>
struct IsMinValue {
    struct pattern_tag {};
    A a;

    constexpr static uint32_t binds = bindings<A>::mask;

    // This rule is a boolean-valued predicate. Bools have type UIntImm.
    constexpr static IRNodeType min_node_type = IRNodeType::UIntImm;
    constexpr static IRNodeType max_node_type = IRNodeType::UIntImm;
    constexpr static bool canonical = true;

    constexpr static bool foldable = true;

    HALIDE_ALWAYS_INLINE
    void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const {
        // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
        // Fold 'a' first; its value and type arrive in val/ty.
        a.make_folded_const(val, ty, state);
        if (ty.code == halide_type_int) {
            // Bit pattern of the most-negative signed value: sign bit set,
            // all higher (sign-extension) bits set in the 64-bit payload.
            const uint64_t min_bits = (uint64_t)(-1) << (ty.bits - 1);
            val.u.u64 = (val.u.u64 == min_bits);
        } else if (ty.code == halide_type_uint) {
            // Minimum of any unsigned type is zero.
            val.u.u64 = (val.u.u64 == 0);
        } else {
            // Floats (and anything else) never match.
            val.u.u64 = 0;
        }
        ty.code = halide_type_uint;
        ty.bits = 1;
    }
};
29590
// Construct a predicate that matches when the sub-pattern folds to the
// minimum value of its integer type.
template<typename A>
HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue<decltype(pattern_arg(a))> {
    assert_is_lvalue_if_expr<A>();
    return {pattern_arg(a)};
}
29596
29597template<typename A>
29598std::ostream &operator<<(std::ostream &s, const IsScalar<A> &op) {
29599 s << "is_scalar(" << op.a << ")";
29600 return s;
29601}
29602
// Verify properties of each rewrite rule. Currently just fuzz tests them.
// Substitutes random constants for the wildcards, constant-folds both sides,
// and internal_error's on a mismatch. Enabled only for rules where both
// sides are constant-foldable (the enable_if below).
template<typename Before,
         typename After,
         typename Predicate,
         typename = typename std::enable_if<std::decay<Before>::type::foldable &&
                                            std::decay<After>::type::foldable>::type>
HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred,
                                        halide_type_t wildcard_type, halide_type_t output_type) noexcept {

    // We only validate the rules in the scalar case
    wildcard_type.lanes = output_type.lanes = 1;

    // Track which types this rule has been tested for before
    static std::set<uint32_t> tested;

    // The wildcard type's bit pattern is the dedupe key; bail if this rule
    // was already fuzzed at this type.
    if (!tested.insert(reinterpret_bits<uint32_t>(wildcard_type)).second) {
        return;
    }

    // Print it in a form where it can be piped into a python/z3 validator
    debug(0) << "validate('" << before << "', '" << after << "', '" << pred << "', " << Type(wildcard_type) << ", " << Type(output_type) << ")\n";

    // Substitute some random constants into the before and after
    // expressions and see if the rule holds true. This should catch
    // silly errors, but not necessarily corner cases.
    static std::mt19937_64 rng(0);
    MatcherState state;

    Expr exprs[max_wild];

    for (int trials = 0; trials < 100; trials++) {
        // We want to test small constants more frequently than
        // large ones, otherwise we'll just get coverage of
        // overflow rules.
        int shift = (int)(rng() & (wildcard_type.bits - 1));

        for (int i = 0; i < max_wild; i++) {
            // Bind all the exprs and constants
            switch (wildcard_type.code) {
            case halide_type_uint: {
                // Normalize to the type's range by adding zero
                uint64_t val = constant_fold_bin_op<Add>(wildcard_type, (uint64_t)rng() >> shift, 0);
                state.set_bound_const(i, val, wildcard_type);
                val = constant_fold_bin_op<Add>(wildcard_type, (uint64_t)rng() >> shift, 0);
                exprs[i] = make_const(wildcard_type, val);
                // NOTE(review): the binding is set again after this switch;
                // this line looks redundant in the uint case — confirm.
                state.set_binding(i, *exprs[i].get());
            } break;
            case halide_type_int: {
                // As above: normalize the random draw into the type's range.
                int64_t val = constant_fold_bin_op<Add>(wildcard_type, (int64_t)rng() >> shift, 0);
                state.set_bound_const(i, val, wildcard_type);
                val = constant_fold_bin_op<Add>(wildcard_type, (int64_t)rng() >> shift, 0);
                exprs[i] = make_const(wildcard_type, val);
            } break;
            case halide_type_float:
            case halide_type_bfloat: {
                // Use a very narrow range of precise floats, so
                // that none of the rules a human is likely to
                // write have instabilities.
                // Draws values from {-4.0, -3.5, ..., 3.5} exactly.
                double val = ((int64_t)(rng() & 15) - 8) / 2.0;
                state.set_bound_const(i, val, wildcard_type);
                val = ((int64_t)(rng() & 15) - 8) / 2.0;
                exprs[i] = make_const(wildcard_type, val);
            } break;
            default:
                return;  // Don't care about handles
            }
            // Bind the wildcard expr for this slot.
            state.set_binding(i, *exprs[i].get());
        }

        halide_scalar_value_t val_pred, val_before, val_after;
        halide_type_t type = output_type;
        // Trials where the rule's predicate is false are vacuous; skip them.
        if (!evaluate_predicate(pred, state)) {
            continue;
        }
        before.make_folded_const(val_before, type, state);
        uint16_t lanes = type.lanes;
        after.make_folded_const(val_after, type, state);
        lanes |= type.lanes;

        // Special value flags (e.g. overflow) are signalled in the lanes
        // field; such trials are not comparable, so skip them.
        if (lanes & MatcherState::special_values_mask) {
            continue;
        }

        bool ok = true;
        switch (output_type.code) {
        case halide_type_uint:
            // Compare normalized representations
            ok &= (constant_fold_bin_op<Add>(output_type, val_before.u.u64, 0) ==
                   constant_fold_bin_op<Add>(output_type, val_after.u.u64, 0));
            break;
        case halide_type_int:
            // Likewise, compare after normalizing into the output type.
            ok &= (constant_fold_bin_op<Add>(output_type, val_before.u.i64, 0) ==
                   constant_fold_bin_op<Add>(output_type, val_after.u.i64, 0));
            break;
        case halide_type_float:
        case halide_type_bfloat: {
            double error = std::abs(val_before.u.f64 - val_after.u.f64);
            // We accept an equal bit pattern (e.g. inf vs inf),
            // a small floating point difference, or turning a nan into not-a-nan.
            ok &= (error < 0.01 ||
                   val_before.u.u64 == val_after.u.u64 ||
                   std::isnan(val_before.u.f64));
            break;
        }
        default:
            return;
        }

        if (!ok) {
            // Dump the failing counterexample before aborting.
            debug(0) << "Fails with values:\n";
            for (int i = 0; i < max_wild; i++) {
                halide_scalar_value_t val;
                state.get_bound_const(i, val, wildcard_type);
                debug(0) << " c" << i << ": " << make_const_expr(val, wildcard_type) << "\n";
            }
            for (int i = 0; i < max_wild; i++) {
                debug(0) << " _" << i << ": " << Expr(state.get_binding(i)) << "\n";
            }
            debug(0) << " Before: " << make_const_expr(val_before, output_type) << "\n";
            debug(0) << " After: " << make_const_expr(val_after, output_type) << "\n";
            debug(0) << val_before.u.u64 << " " << val_after.u.u64 << "\n";
            internal_error;
        }
    }
}
29728
29729template<typename Before,
29730 typename After,
29731 typename Predicate,
29732 typename = typename std::enable_if<!(std::decay<Before>::type::foldable &&
29733 std::decay<After>::type::foldable)>::type>
29734HALIDE_ALWAYS_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred,
29735 halide_type_t, halide_type_t, int dummy = 0) noexcept {
29736 // We can't verify rewrite rules that can't be constant-folded.
29737}
29738
// Base case: a rule whose predicate is a plain compile-time boolean
// (e.g. the literal 'true' passed by the unpredicated rewrite overloads).
HALIDE_ALWAYS_INLINE
bool evaluate_predicate(bool x, MatcherState &) noexcept {
    return x;
}
29743
// Constant-fold a boolean predicate pattern under the bindings currently
// held in 'state'. Returns true only if it folds to a non-zero value and
// folding hit no special values (overflow etc.).
template<typename Pattern,
         typename = typename enable_if_pattern<Pattern>::type>
HALIDE_ALWAYS_INLINE bool evaluate_predicate(Pattern p, MatcherState &state) {
    halide_scalar_value_t c;
    halide_type_t ty = halide_type_of<bool>();
    p.make_folded_const(c, ty, state);
    // Overflow counts as a failed predicate
    return (c.u.u64 != 0) && ((ty.lanes & MatcherState::special_values_mask) == 0);
}
29753
29754// #defines for testing
29755
29756// Print all successful or failed matches
29757#define HALIDE_DEBUG_MATCHED_RULES 0
29758#define HALIDE_DEBUG_UNMATCHED_RULES 0
29759
29760// Set to true if you want to fuzz test every rewrite passed to
29761// operator() to ensure the input and the output have the same value
29762// for lots of random values of the wildcards. Run
29763// correctness_simplify with this on.
29764#define HALIDE_FUZZ_TEST_RULES 0
29765
29766template<typename Instance>
29767struct Rewriter {
29768 Instance instance;
29769 Expr result;
29770 MatcherState state;
29771 halide_type_t output_type, wildcard_type;
29772 bool validate;
29773
29774 HALIDE_ALWAYS_INLINE
29775 Rewriter(Instance instance, halide_type_t ot, halide_type_t wt)
29776 : instance(std::move(instance)), output_type(ot), wildcard_type(wt) {
29777 }
29778
29779 template<typename After>
29780 HALIDE_NEVER_INLINE void build_replacement(After after) {
29781 result = after.make(state, output_type);
29782 }
29783
29784 template<typename Before,
29785 typename After,
29786 typename = typename enable_if_pattern<Before>::type,
29787 typename = typename enable_if_pattern<After>::type>
29788 HALIDE_ALWAYS_INLINE bool operator()(Before before, After after) {
29789 static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values");
29790 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
29791 static_assert(After::canonical, "RHS of rewrite rule should be in canonical form");
29792#if HALIDE_FUZZ_TEST_RULES
29793 fuzz_test_rule(before, after, true, wildcard_type, output_type);
29794#endif
29795 if (before.template match<0>(unwrap(instance), state)) {
29796#if HALIDE_DEBUG_MATCHED_RULES
29797 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
29798#endif
29799 build_replacement(after);
29800 return true;
29801 } else {
29802#if HALIDE_DEBUG_UNMATCHED_RULES
29803 debug(0) << instance << " does not match " << before << "\n";
29804#endif
29805 return false;
29806 }
29807 }
29808
29809 template<typename Before,
29810 typename = typename enable_if_pattern<Before>::type>
29811 HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept {
29812 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
29813 if (before.template match<0>(unwrap(instance), state)) {
29814 result = after;
29815#if HALIDE_DEBUG_MATCHED_RULES
29816 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
29817#endif
29818 return true;
29819 } else {
29820#if HALIDE_DEBUG_UNMATCHED_RULES
29821 debug(0) << instance << " does not match " << before << "\n";
29822#endif
29823 return false;
29824 }
29825 }
29826
29827 template<typename Before,
29828 typename = typename enable_if_pattern<Before>::type>
29829 HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept {
29830 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
29831#if HALIDE_FUZZ_TEST_RULES
29832 fuzz_test_rule(before, IntLiteral(after), true, wildcard_type, output_type);
29833#endif
29834 if (before.template match<0>(unwrap(instance), state)) {
29835 result = make_const(output_type, after);
29836#if HALIDE_DEBUG_MATCHED_RULES
29837 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
29838#endif
29839 return true;
29840 } else {
29841#if HALIDE_DEBUG_UNMATCHED_RULES
29842 debug(0) << instance << " does not match " << before << "\n";
29843#endif
29844 return false;
29845 }
29846 }
29847
29848 template<typename Before,
29849 typename After,
29850 typename Predicate,
29851 typename = typename enable_if_pattern<Before>::type,
29852 typename = typename enable_if_pattern<After>::type,
29853 typename = typename enable_if_pattern<Predicate>::type>
29854 HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred) {
29855 static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
29856 static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values");
29857 static_assert((Before::binds & Predicate::binds) == Predicate::binds, "Rule predicate uses unbound values");
29858 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
29859 static_assert(After::canonical, "RHS of rewrite rule should be in canonical form");
29860
29861#if HALIDE_FUZZ_TEST_RULES
29862 fuzz_test_rule(before, after, pred, wildcard_type, output_type);
29863#endif
29864 if (before.template match<0>(unwrap(instance), state) &&
29865 evaluate_predicate(pred, state)) {
29866#if HALIDE_DEBUG_MATCHED_RULES
29867 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
29868#endif
29869 build_replacement(after);
29870 return true;
29871 } else {
29872#if HALIDE_DEBUG_UNMATCHED_RULES
29873 debug(0) << instance << " does not match " << before << "\n";
29874#endif
29875 return false;
29876 }
29877 }
29878
29879 template<typename Before,
29880 typename Predicate,
29881 typename = typename enable_if_pattern<Before>::type,
29882 typename = typename enable_if_pattern<Predicate>::type>
29883 HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred) {
29884 static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
29885 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
29886
29887 if (before.template match<0>(unwrap(instance), state) &&
29888 evaluate_predicate(pred, state)) {
29889 result = after;
29890#if HALIDE_DEBUG_MATCHED_RULES
29891 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
29892#endif
29893 return true;
29894 } else {
29895#if HALIDE_DEBUG_UNMATCHED_RULES
29896 debug(0) << instance << " does not match " << before << "\n";
29897#endif
29898 return false;
29899 }
29900 }
29901
29902 template<typename Before,
29903 typename Predicate,
29904 typename = typename enable_if_pattern<Before>::type,
29905 typename = typename enable_if_pattern<Predicate>::type>
29906 HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred) {
29907 static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
29908 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
29909#if HALIDE_FUZZ_TEST_RULES
29910 fuzz_test_rule(before, IntLiteral(after), pred, wildcard_type, output_type);
29911#endif
29912 if (before.template match<0>(unwrap(instance), state) &&
29913 evaluate_predicate(pred, state)) {
29914 result = make_const(output_type, after);
29915#if HALIDE_DEBUG_MATCHED_RULES
29916 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
29917#endif
29918 return true;
29919 } else {
29920#if HALIDE_DEBUG_UNMATCHED_RULES
29921 debug(0) << instance << " does not match " << before << "\n";
29922#endif
29923 return false;
29924 }
29925 }
29926};
29927
29928/** Construct a rewriter for the given instance, which may be a pattern
29929 * with concrete expressions as leaves, or just an expression. The
29930 * second optional argument (wildcard_type) is a hint as to what the
29931 * type of the wildcards is likely to be. If omitted it uses the same
29932 * type as the expression itself. They are not required to be this
29933 * type, but the rule will only be tested for wildcards of that type
29934 * when testing is enabled.
29935 *
29936 * The rewriter can be used to check to see if the instance is one of
29937 * some number of patterns and if so rewrite it into another form,
29938 * using its operator() method. See Simplify.cpp for a bunch of
29939 * example usage.
29940 *
29941 * Important: Any Exprs in patterns are captured by reference, not by
29942 * value, so ensure they outlive the rewriter.
29943 */
29944// @{
// Pattern instance with explicit output and wildcard types.
template<typename Instance,
         typename = typename enable_if_pattern<Instance>::type>
HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter<decltype(pattern_arg(instance))> {
    return {pattern_arg(instance), output_type, wildcard_type};
}
29950
// Pattern instance; wildcards are assumed to share the output type.
template<typename Instance,
         typename = typename enable_if_pattern<Instance>::type>
HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type) noexcept -> Rewriter<decltype(pattern_arg(instance))> {
    return {pattern_arg(instance), output_type, output_type};
}
29956
// Expr instance with an explicit wildcard-type hint; the output type is
// taken from the Expr itself.
HALIDE_ALWAYS_INLINE
auto rewriter(const Expr &e, halide_type_t wildcard_type) noexcept -> Rewriter<decltype(pattern_arg(e))> {
    return {pattern_arg(e), e.type(), wildcard_type};
}
29961
// Expr instance; both the output and wildcard types are taken from the Expr.
HALIDE_ALWAYS_INLINE
auto rewriter(const Expr &e) noexcept -> Rewriter<decltype(pattern_arg(e))> {
    return {pattern_arg(e), e.type(), e.type()};
}
29966// @}
29967
29968} // namespace IRMatcher
29969
29970} // namespace Internal
29971} // namespace Halide
29972
29973#endif
29974#ifndef HALIDE_IR_MUTATOR_H
29975#define HALIDE_IR_MUTATOR_H
29976
29977/** \file
29978 * Defines a base class for passes over the IR that modify it
29979 */
29980
29981#include <map>
29982
29983
29984namespace Halide {
29985namespace Internal {
29986
29987/** A base class for passes over the IR which modify it
29988 * (e.g. replacing a variable with a value (Substitute.h), or
29989 * constant-folding).
29990 *
29991 * Your mutator should override the visit() methods you care about and return
29992 * the new expression or stmt. The default implementations recursively
29993 * mutate their children. To mutate sub-expressions and sub-statements you
29994 * should override the mutate() method, which will dispatch to
29995 * the appropriate visit() method and then return the value of expr or
29996 * stmt after the call to visit.
29997 */
class IRMutator {
public:
    IRMutator() = default;
    virtual ~IRMutator() = default;

    /** This is the main interface for using a mutator. Also call
     * these in your subclass to mutate sub-expressions and
     * sub-statements.
     */
    virtual Expr mutate(const Expr &expr);
    virtual Stmt mutate(const Stmt &stmt);

protected:
    // ExprNode<> and StmtNode<> are allowed to call visit (to implement mutate_expr/mutate_stmt())
    template<typename T>
    friend struct ExprNode;
    template<typename T>
    friend struct StmtNode;

    // One visit method per concrete Expr IR node. Override the ones you
    // care about; return the (possibly new) expression.
    virtual Expr visit(const IntImm *);
    virtual Expr visit(const UIntImm *);
    virtual Expr visit(const FloatImm *);
    virtual Expr visit(const StringImm *);
    virtual Expr visit(const Cast *);
    virtual Expr visit(const Variable *);
    virtual Expr visit(const Add *);
    virtual Expr visit(const Sub *);
    virtual Expr visit(const Mul *);
    virtual Expr visit(const Div *);
    virtual Expr visit(const Mod *);
    virtual Expr visit(const Min *);
    virtual Expr visit(const Max *);
    virtual Expr visit(const EQ *);
    virtual Expr visit(const NE *);
    virtual Expr visit(const LT *);
    virtual Expr visit(const LE *);
    virtual Expr visit(const GT *);
    virtual Expr visit(const GE *);
    virtual Expr visit(const And *);
    virtual Expr visit(const Or *);
    virtual Expr visit(const Not *);
    virtual Expr visit(const Select *);
    virtual Expr visit(const Load *);
    virtual Expr visit(const Ramp *);
    virtual Expr visit(const Broadcast *);
    virtual Expr visit(const Call *);
    virtual Expr visit(const Let *);
    virtual Expr visit(const Shuffle *);
    virtual Expr visit(const VectorReduce *);

    // One visit method per concrete Stmt IR node.
    virtual Stmt visit(const LetStmt *);
    virtual Stmt visit(const AssertStmt *);
    virtual Stmt visit(const ProducerConsumer *);
    virtual Stmt visit(const For *);
    virtual Stmt visit(const Store *);
    virtual Stmt visit(const Provide *);
    virtual Stmt visit(const Allocate *);
    virtual Stmt visit(const Free *);
    virtual Stmt visit(const Realize *);
    virtual Stmt visit(const Block *);
    virtual Stmt visit(const IfThenElse *);
    virtual Stmt visit(const Evaluate *);
    virtual Stmt visit(const Prefetch *);
    virtual Stmt visit(const Acquire *);
    virtual Stmt visit(const Fork *);
    virtual Stmt visit(const Atomic *);
};
30065
30066/** A mutator that caches and reapplies previously-done mutations, so
30067 * that it can handle graphs of IR that have not had CSE done to
30068 * them. */
class IRGraphMutator : public IRMutator {
protected:
    // Cache of already-mutated nodes, keyed by the original Expr/Stmt, so
    // shared sub-graphs are mutated once and the result reused.
    std::map<Expr, Expr, ExprCompare> expr_replacements;
    std::map<Stmt, Stmt, Stmt::Compare> stmt_replacements;

public:
    Stmt mutate(const Stmt &s) override;
    Expr mutate(const Expr &e) override;
};
30078
30079/** A helper function for mutator-like things to mutate regions */
30080template<typename Mutator, typename... Args>
30081std::pair<Region, bool> mutate_region(Mutator *mutator, const Region &bounds, Args &&...args) {
30082 Region new_bounds(bounds.size());
30083 bool bounds_changed = false;
30084
30085 for (size_t i = 0; i < bounds.size(); i++) {
30086 Expr old_min = bounds[i].min;
30087 Expr old_extent = bounds[i].extent;
30088 Expr new_min = mutator->mutate(old_min, std::forward<Args>(args)...);
30089 Expr new_extent = mutator->mutate(old_extent, std::forward<Args>(args)...);
30090 if (!new_min.same_as(old_min)) {
30091 bounds_changed = true;
30092 }
30093 if (!new_extent.same_as(old_extent)) {
30094 bounds_changed = true;
30095 }
30096 new_bounds[i] = Range(new_min, new_extent);
30097 }
30098 return {new_bounds, bounds_changed};
30099}
30100
30101} // namespace Internal
30102} // namespace Halide
30103
30104#endif
30105#ifndef HALIDE_LERP_H
30106#define HALIDE_LERP_H
30107
30108/** \file
30109 * Defines methods for converting a lerp intrinsic into Halide IR.
30110 */
30111
30112
30113namespace Halide {
30114namespace Internal {
30115
30116/** Build Halide IR that computes a lerp. Use by codegen targets that
30117 * don't have a native lerp. */
30118Expr lower_lerp(Expr zero_val, Expr one_val, const Expr &weight);
30119
30120} // namespace Internal
30121} // namespace Halide
30122
30123#endif
30124#ifndef HALIDE_LICM_H
30125#define HALIDE_LICM_H
30126
30127/** \file
30128 * Methods for lifting loop invariants out of inner loops.
30129 */
30130
30131
30132namespace Halide {
30133namespace Internal {
30134
30135/** Hoist loop-invariants out of inner loops. This is especially
30136 * important in cases where LLVM would not do it for us
30137 * automatically. For example, it hoists loop invariants out of cuda
30138 * kernels. */
30139Stmt hoist_loop_invariant_values(Stmt);
30140
30141/** Just hoist loop-invariant if statements as far up as
30142 * possible. Does not lift other values. It's useful to run this
30143 * earlier in lowering to simplify the IR. */
30144Stmt hoist_loop_invariant_if_statements(Stmt);
30145
30146} // namespace Internal
30147} // namespace Halide
30148
30149#endif
30150#ifndef HALIDE_LLVM_OUTPUTS_H
30151#define HALIDE_LLVM_OUTPUTS_H
30152
30153/** \file
30154 *
30155 */
30156
30157#include <memory>
30158#include <string>
30159#include <vector>
30160
30161namespace llvm {
30162class Module;
30163class TargetOptions;
30164class LLVMContext;
30165class raw_fd_ostream;
30166class raw_pwrite_stream;
30167class raw_ostream;
30168} // namespace llvm
30169
30170namespace Halide {
30171
30172class Module;
30173struct Target;
30174
30175namespace Internal {
30176typedef llvm::raw_pwrite_stream LLVMOStream;
30177}
30178
30179/** Generate an LLVM module. */
30180std::unique_ptr<llvm::Module> compile_module_to_llvm_module(const Module &module, llvm::LLVMContext &context);
30181
30182/** Construct an llvm output stream for writing to files. */
30183std::unique_ptr<llvm::raw_fd_ostream> make_raw_fd_ostream(const std::string &filename);
30184
30185/** Compile an LLVM module to native targets (objects, native assembly). */
30186// @{
30187void compile_llvm_module_to_object(llvm::Module &module, Internal::LLVMOStream &out);
30188void compile_llvm_module_to_assembly(llvm::Module &module, Internal::LLVMOStream &out);
30189// @}
30190
30191/** Compile an LLVM module to LLVM targets (bitcode, LLVM assembly). */
30192// @{
30193void compile_llvm_module_to_llvm_bitcode(llvm::Module &module, Internal::LLVMOStream &out);
30194void compile_llvm_module_to_llvm_assembly(llvm::Module &module, Internal::LLVMOStream &out);
30195// @}
30196
30197/**
30198 * Concatenate the list of src_files into dst_file, using the appropriate
30199 * static library format for the given target (e.g., .a or .lib).
30200 * If deterministic is true, emit 0 for all GID/UID/timestamps, and 0644 for
30201 * all modes (equivalent to the ar -D option).
30202 */
30203void create_static_library(const std::vector<std::string> &src_files, const Target &target,
30204 const std::string &dst_file, bool deterministic = true);
30205} // namespace Halide
30206
30207#endif
30208#ifndef HALIDE_LLVM_RUNTIME_LINKER_H
30209#define HALIDE_LLVM_RUNTIME_LINKER_H
30210
30211/** \file
30212 * Support for linking LLVM modules that comprise the runtime.
30213 */
30214
30215#include <memory>
30216#include <string>
30217#include <vector>
30218
30219namespace llvm {
30220class GlobalValue;
30221class Module;
30222class LLVMContext;
30223class Triple;
30224} // namespace llvm
30225
30226namespace Halide {
30227
30228struct Target;
30229
30230namespace Internal {
30231
30232/** Return the llvm::Triple that corresponds to the given Halide Target */
30233llvm::Triple get_triple_for_target(const Target &target);
30234
30235/** Create an llvm module containing the support code for a given target. */
30236std::unique_ptr<llvm::Module> get_initial_module_for_target(Target, llvm::LLVMContext *, bool for_shared_jit_runtime = false, bool just_gpu = false);
30237
30238/** Create an llvm module containing the support code for ptx device. */
30239std::unique_ptr<llvm::Module> get_initial_module_for_ptx_device(Target, llvm::LLVMContext *c);
30240
30241/** Link a block of llvm bitcode into an llvm module. */
30242void add_bitcode_to_module(llvm::LLVMContext *context, llvm::Module &module,
30243 const std::vector<uint8_t> &bitcode, const std::string &name);
30244
30245/** Take the llvm::Module(s) in extra_modules (if any), add the runtime modules needed for the WASM JIT,
30246 * and link into a single llvm::Module. */
30247std::unique_ptr<llvm::Module> link_with_wasm_jit_runtime(llvm::LLVMContext *c, const Target &t,
30248 std::unique_ptr<llvm::Module> extra_module);
30249
30250} // namespace Internal
30251} // namespace Halide
30252
30253#endif
30254#ifndef HALIDE_LOOP_CARRY_H
30255#define HALIDE_LOOP_CARRY_H
30256
30257
30258namespace Halide {
30259namespace Internal {
30260
30261/** Reuse loads done on previous loop iterations by stashing them in
30262 * induction variables instead of redoing the load. If the loads are
30263 * predicated, the predicates need to match. Can be an optimization or
30264 * pessimization depending on how good the L1 cache is on the architecture
30265 * and how many memory issue slots there are. Currently only intended
30266 * for Hexagon. */
30267Stmt loop_carry(Stmt, int max_carried_values = 8);
30268
30269} // namespace Internal
30270} // namespace Halide
30271
30272#endif
30273#ifndef HALIDE_INTERNAL_LOWER_H
30274#define HALIDE_INTERNAL_LOWER_H
30275
30276/** \file
30277 *
30278 * Defines the function that generates a statement that computes a
30279 * Halide function using its schedule.
30280 */
30281
30282#include <string>
30283#include <vector>
30284
30285
30286namespace Halide {
30287
30288struct Target;
30289
30290namespace Internal {
30291
30292class Function;
30293class IRMutator;
30294
/** Given a vector of scheduled halide functions, create a Module that
 * evaluates it. Automatically pulls in all the functions f depends
 * on. Some stages of lowering may be target-specific. The Module may
 * contain submodules for computation offloaded to another execution
 * engine or API as well as buffers that are used in the passed in
 * Stmt. 'requirements' are extra statements included in the result
 * (presumably checks run before the pipeline body -- confirm), and
 * 'custom_passes' are caller-supplied IRMutators applied during
 * lowering. */
Module lower(const std::vector<Function> &output_funcs,
             const std::string &pipeline_name,
             const Target &t,
             const std::vector<Argument> &args,
             LinkageType linkage_type,
             const std::vector<Stmt> &requirements = std::vector<Stmt>(),
             bool trace_pipeline = false,
             const std::vector<IRMutator *> &custom_passes = std::vector<IRMutator *>());

/** Given a halide function with a schedule, create a statement that
 * evaluates it. Automatically pulls in all the functions f depends
 * on. Some stages of lowering may be target-specific. Mostly used as
 * a convenience function in tests that wish to assert some property
 * of the lowered IR. */
Stmt lower_main_stmt(const std::vector<Function> &output_funcs,
                     const std::string &pipeline_name,
                     const Target &t,
                     const std::vector<Stmt> &requirements = std::vector<Stmt>(),
                     bool trace_pipeline = false,
                     const std::vector<IRMutator *> &custom_passes = std::vector<IRMutator *>());

/** Self-test entry point for the lowering machinery. */
void lower_test();
30323
30324} // namespace Internal
30325} // namespace Halide
30326
30327#endif
30328#ifndef HALIDE_LOWER_WARP_SHUFFLES_H
30329#define HALIDE_LOWER_WARP_SHUFFLES_H
30330
30331/** \file
30332 * Defines the lowering pass that injects CUDA warp shuffle
30333 * instructions to access storage outside of a GPULane loop.
30334 */
30335
30336
30337namespace Halide {
30338namespace Internal {
30339
/** Rewrite accesses to things stored outside the loop over GPU lanes to
 * use NVIDIA's warp shuffle instructions. */
Stmt lower_warp_shuffles(Stmt s);
30343
30344} // namespace Internal
30345} // namespace Halide
30346
30347#endif
30348/** \file
30349 * This file only exists to contain the front-page of the documentation
30350 */
30351
30352/** \mainpage Halide
30353 *
30354 * Halide is a programming language designed to make it easier to
30355 * write high-performance image processing code on modern
30356 * machines. Its front end is embedded in C++. Compiler
30357 * targets include x86/SSE, ARM v7/NEON, CUDA, Native Client,
30358 * OpenCL, and Metal.
30359 *
30360 * You build a Halide program by writing C++ code using objects of
30361 * type \ref Halide::Var, \ref Halide::Expr, and \ref Halide::Func,
30362 * and then calling \ref Halide::Func::compile_to_file to generate an
30363 * object file and header (good for deploying large routines), or
30364 * calling \ref Halide::Func::realize to JIT-compile and run the
30365 * pipeline immediately (good for testing small routines).
30366 *
 * To learn Halide, we recommend you start with the <a href="examples.html">tutorials</a>.
30368 *
30369 * You can also look in the test folder for many small examples that
30370 * use Halide's various features, and in the apps folder for some
30371 * larger examples that statically compile halide pipelines. In
30372 * particular check out local_laplacian, bilateral_grid, and
30373 * interpolate.
30374 *
30375 * Below are links to the documentation for the important classes in Halide.
30376 *
30377 * For defining, scheduling, and evaluating basic pipelines:
30378 *
30379 * Halide::Func, Halide::Stage, Halide::Var
30380 *
30381 * Our image data type:
30382 *
30383 * Halide::Buffer
30384 *
30385 * For passing around and reusing halide expressions:
30386 *
30387 * Halide::Expr
30388 *
30389 * For representing scalar and image parameters to pipelines:
30390 *
30391 * Halide::Param, Halide::ImageParam
30392 *
30393 * For writing functions that reduce or scatter over some domain:
30394 *
30395 * Halide::RDom
30396 *
30397 * For writing and evaluating functions that return multiple values:
30398 *
30399 * Halide::Tuple, Halide::Realization
30400 *
30401 */
30402
30403/**
30404 * \example tutorial/lesson_01_basics.cpp
30405 * \example tutorial/lesson_02_input_image.cpp
30406 * \example tutorial/lesson_03_debugging_1.cpp
30407 * \example tutorial/lesson_04_debugging_2.cpp
30408 * \example tutorial/lesson_05_scheduling_1.cpp
30409 * \example tutorial/lesson_06_realizing_over_shifted_domains.cpp
30410 * \example tutorial/lesson_07_multi_stage_pipelines.cpp
30411 * \example tutorial/lesson_08_scheduling_2.cpp
30412 * \example tutorial/lesson_09_update_definitions.cpp
30413 * \example tutorial/lesson_10_aot_compilation_generate.cpp
30414 * \example tutorial/lesson_10_aot_compilation_run.cpp
30415 * \example tutorial/lesson_11_cross_compilation.cpp
30416 * \example tutorial/lesson_12_using_the_gpu.cpp
30417 * \example tutorial/lesson_13_tuples.cpp
30418 * \example tutorial/lesson_14_types.cpp
30419 * \example tutorial/lesson_15_generators.cpp
30420 */
30421#ifndef HALIDE_MATLAB_OUTPUT_H
30422#define HALIDE_MATLAB_OUTPUT_H
30423
30424/** \file
30425 *
30426 * Provides an output function to generate a Matlab mex API compatible object file.
30427 */
30428
30429namespace llvm {
30430class Module;
30431class Function;
30432class Value;
30433} // namespace llvm
30434
30435namespace Halide {
30436namespace Internal {
30437
/** Add a mexFunction wrapper definition to 'module', calling
 * 'pipeline_argv_wrapper' (the pipeline's argv-style entry point).
 * 'metadata_getter' is presumably used to query the pipeline's argument
 * metadata -- confirm in implementation. Returns the mexFunction
 * definition. */
llvm::Function *define_matlab_wrapper(llvm::Module *module,
                                      llvm::Function *pipeline_argv_wrapper,
                                      llvm::Function *metadata_getter);
30444
30445} // namespace Internal
30446} // namespace Halide
30447
30448#endif
30449#ifndef HALIDE_INTERNAL_CACHING_H
30450#define HALIDE_INTERNAL_CACHING_H
30451
30452/** \file
30453 *
30454 * Defines the interface to the pass that injects support for
30455 * compute_cached roots.
30456 */
30457
30458#include <map>
30459#include <string>
30460
30461
30462namespace Halide {
30463namespace Internal {
30464
30465class Function;
30466
/** Transform pipeline calls for Funcs scheduled with memoize to do a
 * lookup call to the runtime cache implementation, and if there is a
 * miss, compute the results and call the runtime to store them back to
 * the cache. 'name' is presumably the pipeline name used by the runtime
 * cache -- confirm against callers.
 * Should leave non-memoized Funcs unchanged.
 */
Stmt inject_memoization(const Stmt &s, const std::map<std::string, Function> &env,
                        const std::string &name,
                        const std::vector<Function> &outputs);

/** This should be called after Storage Flattening has added Allocation
 * IR nodes. It connects the memoization cache lookups to the Allocations
 * so they point to the buffers from the memoization cache and those buffers
 * are released when no longer used.
 * Should not affect allocations for non-memoized Funcs.
 */
Stmt rewrite_memoized_allocations(const Stmt &s, const std::map<std::string, Function> &env);
30484
30485} // namespace Internal
30486} // namespace Halide
30487
30488#endif
30489#ifndef HALIDE_MONOTONIC_H
30490#define HALIDE_MONOTONIC_H
30491
30492/** \file
30493 *
30494 * Methods for computing whether expressions are monotonic
30495 */
30496#include <iostream>
30497#include <string>
30498
30499
30500namespace Halide {
30501namespace Internal {
30502
/** Find the bounds of the derivative of an expression with respect to
 * 'var'. */
ConstantInterval derivative_bounds(const Expr &e, const std::string &var,
                                   const Scope<ConstantInterval> &scope = Scope<ConstantInterval>::empty_scope());

/**
 * Detect whether an expression is monotonic increasing in a variable,
 * decreasing, or unknown.
 */
enum class Monotonic { Constant,
                       Increasing,
                       Decreasing,
                       Unknown };
/** Determine the monotonicity of 'e' with respect to 'var'. The second
 * overload accepts already-known monotonicities of other variables. */
Monotonic is_monotonic(const Expr &e, const std::string &var,
                       const Scope<ConstantInterval> &scope = Scope<ConstantInterval>::empty_scope());
Monotonic is_monotonic(const Expr &e, const std::string &var, const Scope<Monotonic> &scope);

/** Emit the monotonic class in human-readable form for debugging. */
std::ostream &operator<<(std::ostream &stream, const Monotonic &m);

/** Self-test for this module. */
void is_monotonic_test();
30523
30524} // namespace Internal
30525} // namespace Halide
30526
30527#endif
30528#ifndef HALIDE_OFFLOAD_GPU_LOOPS_H
30529#define HALIDE_OFFLOAD_GPU_LOOPS_H
30530
30531/** \file
30532 * Defines a lowering pass to pull loops marked with
30533 * GPU device APIs to a separate module, and call them through the
30534 * appropriate host runtime module.
30535 */
30536
30537
30538namespace Halide {
30539
30540struct Target;
30541
30542namespace Internal {
30543
/** Pull loops marked with GPU device APIs to a separate
 * module, and call them through the appropriate host runtime module.
 * 'host_target' describes the host code that launches the offloaded
 * loops. */
Stmt inject_gpu_offload(const Stmt &s, const Target &host_target);
30547
30548} // namespace Internal
30549} // namespace Halide
30550
30551#endif
30552#ifndef HALIDE_PARALLEL_RVAR_H
30553#define HALIDE_PARALLEL_RVAR_H
30554
30555/** \file
30556 *
30557 * Method for checking if it's safe to parallelize an update
30558 * definition across a reduction variable.
30559 */
30560
30561#include <string>
30562
30563namespace Halide {
30564namespace Internal {
30565
30566class Definition;
30567
/** Returns whether or not Halide can prove that it is safe to
 * parallelize an update definition across a specific variable. If
 * this returns true, it's definitely safe. If this returns false, it
 * may still be safe, but Halide couldn't prove it. 'rvar' is the name
 * of the reduction variable, 'func' the name of the function, and 'r'
 * the update definition being checked.
 */
bool can_parallelize_rvar(const std::string &rvar,
                          const std::string &func,
                          const Definition &r);
30576
30577} // namespace Internal
30578} // namespace Halide
30579
30580#endif
30581#ifndef PARTITION_LOOPS_H
30582#define PARTITION_LOOPS_H
30583
/** \file
 * Defines a lowering pass that partitions loop bodies into three
 * sections to handle boundary conditions: a prologue, a simplified
 * steady-state, and an epilogue.
 */
30589
30590
30591namespace Halide {
30592namespace Internal {
30593
/** Return true if an expression uses a likely tag that isn't captured
 * by an enclosing Select, Min, or Max. */
bool has_uncaptured_likely_tag(const Expr &e);

/** Return true if an expression uses a likely tag anywhere. */
bool has_likely_tag(const Expr &e);

/** Partitions loop bodies into a prologue, a steady state, and an
 * epilogue. Finds the steady state by hunting for use of clamped
 * ramps, or the 'likely' intrinsic. */
Stmt partition_loops(Stmt s);
30605
30606} // namespace Internal
30607} // namespace Halide
30608
30609#endif
30610#ifndef HALIDE_PREFETCH_H
30611#define HALIDE_PREFETCH_H
30612
30613/** \file
30614 * Defines the lowering pass that injects prefetch calls when prefetching
30615 * appears in the schedule.
30616 */
30617
30618#include <map>
30619#include <string>
30620#include <vector>
30621
30622namespace Halide {
30623
30624struct Target;
30625
30626namespace Internal {
30627
30628class Function;
30629struct PrefetchDirective;
30630struct Stmt;
30631
/** Inject placeholder prefetches to 's'. This placeholder prefetch
 * does not have an explicit region to be prefetched yet. It will be
 * computed during the call to \ref inject_prefetch. */
Stmt inject_placeholder_prefetch(const Stmt &s, const std::map<std::string, Function> &env,
                                 const std::string &prefix,
                                 const std::vector<PrefetchDirective> &prefetches);
/** Compute the actual region to be prefetched and place it in the
 * placeholder prefetch. Wrap the prefetch call with a condition when
 * applicable. */
Stmt inject_prefetch(const Stmt &s, const std::map<std::string, Function> &env);

/** Reduce a multi-dimensional prefetch into a prefetch of lower dimension
 * (max dimension of the prefetch is specified by target architecture).
 * This keeps the 'max_dim' innermost dimensions and adds loops for the rest
 * of the dimensions. If maximum prefetched-byte-size is specified (depending
 * on the architecture), this also adds outer loops that tile the prefetches. */
Stmt reduce_prefetch_dimension(Stmt stmt, const Target &t);
30649
30650} // namespace Internal
30651} // namespace Halide
30652
30653#endif
30654#ifndef HALIDE_PROFILING_H
30655#define HALIDE_PROFILING_H
30656
30657/** \file
30658 * Defines the lowering pass that injects print statements when profiling is turned on.
30659 * The profiler will print out per-pipeline and per-func stats, such as total time
30660 * spent and heap/stack allocation information. To turn on the profiler, set
30661 * HL_TARGET/HL_JIT_TARGET flags to 'host-profile'.
30662 *
30663 * Output format:
30664 * \<pipeline_name\>
30665 * \<total time spent in this pipeline\> \<# of samples taken\> \<# of runs\> \<avg time/run\>
30666 * \<# of heap allocations\> \<peak heap allocation\>
30667 * \<func_name\> \<total time spent in this func\> \<percentage of time spent\>
30668 * (\<peak heap alloc by this func\> \<num of allocs\> \<average alloc size\> |
30669 * \<worst-case peak stack alloc by this func\>)?
30670 *
30671 * Sample output:
30672 * memory_profiler_mandelbrot
30673 * total time: 59.832336 ms samples: 43 runs: 1000 time/run: 0.059832 ms
30674 * heap allocations: 104000 peak heap usage: 505344 bytes
30675 * f0: 0.025673ms (42%)
30676 * mandelbrot: 0.006444ms (10%) peak: 505344 num: 104000 avg: 5376
30677 * argmin: 0.027715ms (46%) stack: 20
30678 */
30679#include <string>
30680
30681
30682namespace Halide {
30683namespace Internal {
30684
/** Take a statement representing a halide pipeline, insert
 * high-resolution timing into the generated code (via spawning a
 * thread that acts as a sampling profiler); summaries of execution
 * times and counts will be logged at the end. Should be done before
 * storage flattening, but after all bounds inference.
 *
 * The string argument is presumably the pipeline name shown in the
 * profiler report described above -- confirm against callers.
 */
Stmt inject_profiling(Stmt, const std::string &);
30693
30694} // namespace Internal
30695} // namespace Halide
30696
30697#endif
30698#ifndef HALIDE_PURIFY_INDEX_MATH_H
30699#define HALIDE_PURIFY_INDEX_MATH_H
30700
30701/** \file
30702 * Removes side-effects in integer math.
30703 */
30704
30705
30706namespace Halide {
30707namespace Internal {
30708
/** Bounds inference and related stages can lift integer bounds
 * expressions out of if statements that guard against those integer
 * expressions doing side-effecty things like dividing or modding by
 * zero. In those cases, if the lowering passes are functional, the
 * value resulting from the division or mod is evaluated but not
 * used. This mutator rewrites divs and mods in such expressions to
 * fail silently (evaluate to undef) when the denominator is zero.
 * Returns the rewritten ("purified") expression.
 */
Expr purify_index_math(const Expr &);
30718
30719} // namespace Internal
30720} // namespace Halide
30721
30722#endif
30723#ifndef HALIDE_PYTHON_EXTENSION_GEN_H_
30724#define HALIDE_PYTHON_EXTENSION_GEN_H_
30725
30726#include <iostream>
30727#include <string>
30728#include <vector>
30729
30730namespace Halide {
30731
30732class Module;
30733
30734namespace Internal {
30735
30736struct LoweredArgument;
30737struct LoweredFunc;
30738
/** Generates source for a Python extension module wrapping the public
 * functions of a Module. The generated text is written to 'dest'. */
class PythonExtensionGen {
public:
    /** 'dest' receives the generated source; it must outlive this object. */
    PythonExtensionGen(std::ostream &dest);

    /** Emit wrapper code for all functions in 'module'. */
    void compile(const Module &module);

private:
    // Destination stream for the generated source.
    std::ostream &dest;
    // Names of buffers converted so far -- presumably consumed by
    // release_buffers() to emit cleanup code; confirm in implementation.
    std::vector<std::string> buffer_refs;

    // Emit the wrapper for a single lowered function.
    void compile(const LoweredFunc &f);
    // Emit code converting the Python argument 'name' described by 'arg'.
    void convert_buffer(const std::string &name, const LoweredArgument *arg);
    // Emit code releasing previously converted buffers.
    void release_buffers(const std::string &prefix);
};
30753
30754} // namespace Internal
30755} // namespace Halide
30756
30757#endif // HALIDE_PYTHON_EXTENSION_GEN_H_
30758#ifndef HALIDE_QUALIFY_H
30759#define HALIDE_QUALIFY_H
30760
30761/** \file
30762 *
30763 * Defines methods for prefixing names in an expression with a prefix string.
30764 */
30765#include <string>
30766
30767
30768namespace Halide {
30769namespace Internal {
30770
/** Prefix all variable names in the given expression with the prefix
 * string. Returns a new expression; 'value' is not mutated. */
Expr qualify(const std::string &prefix, const Expr &value);
30773
30774} // namespace Internal
30775} // namespace Halide
30776
30777#endif
30778#ifndef HALIDE_RANDOM_H
30779#define HALIDE_RANDOM_H
30780
30781/** \file
30782 *
30783 * Defines deterministic random functions, and methods to redirect
30784 * front-end calls to random_float and random_int to use them. */
30785
30786#include <vector>
30787
30788
30789namespace Halide {
30790namespace Internal {
30791
/** Return a random floating-point number between zero and one that
 * varies deterministically based on the input expressions. */
Expr random_float(const std::vector<Expr> &);

/** Return a random unsigned integer between zero and 2^32-1 that
 * varies deterministically based on the input expressions (which must
 * be integers or unsigned integers). */
Expr random_int(const std::vector<Expr> &);

/** Convert calls to random() to IR generated by random_float and
 * random_int. Tags all calls with the variables in 'free_vars' and
 * with the integer 'tag'. */
Expr lower_random(const Expr &e, const std::vector<VarOrRVar> &free_vars, int tag);
30805
30806} // namespace Internal
30807} // namespace Halide
30808
30809#endif
30810#ifndef HALIDE_INTERNAL_REALIZATION_ORDER_H
30811#define HALIDE_INTERNAL_REALIZATION_ORDER_H
30812
30813/** \file
30814 *
30815 * Defines the lowering pass that determines the order in which
30816 * realizations are injected and groups functions with fused
30817 * computation loops.
30818 */
30819
30820#include <map>
30821#include <string>
30822#include <vector>
30823
30824namespace Halide {
30825namespace Internal {
30826
30827class Function;
30828
/** Given a bunch of functions that call each other, determine an
 * order in which to do the scheduling. This in turn influences the
 * order in which stages are computed when there's no strict
 * dependency between them. Currently just some arbitrary depth-first
 * traversal of the call graph. In addition, determine grouping of functions
 * with fused computation loops. The functions within the fused groups
 * are sorted based on realization order. There should not be any dependencies
 * among functions within a fused group. This pass will also populate the
 * 'fused_pairs' list in the function's schedule ('env' is taken by
 * non-const reference for this reason). Returns a pair of the
 * realization order and the fused groups in that order.
 */
std::pair<std::vector<std::string>, std::vector<std::vector<std::string>>> realization_order(
    const std::vector<Function> &outputs, std::map<std::string, Function> &env);

/** Given a bunch of functions that call each other, determine a
 * topological order which stays constant regardless of the schedule.
 * This ordering adheres to the producer-consumer dependencies, i.e. a
 * producer will come before its consumers in that order. */
std::vector<std::string> topological_order(
    const std::vector<Function> &outputs, const std::map<std::string, Function> &env);
30849
30850} // namespace Internal
30851} // namespace Halide
30852
30853#endif
30854#ifndef HALIDE_REBASE_LOOPS_TO_ZERO_H
30855#define HALIDE_REBASE_LOOPS_TO_ZERO_H
30856
30857/** \file
30858 * Defines the lowering pass that rewrites loop mins to be 0.
30859 */
30860
30861
30862namespace Halide {
30863namespace Internal {
30864
/** Rewrite the mins of most loops to 0. Returns the rewritten
 * statement. */
Stmt rebase_loops_to_zero(const Stmt &);
30867
30868} // namespace Internal
30869} // namespace Halide
30870
30871#endif
30872#ifndef HALIDE_INTERNAL_REGION_COSTS_H
30873#define HALIDE_INTERNAL_REGION_COSTS_H
30874
30875/** \file
30876 *
30877 * Defines RegionCosts - used by the auto scheduler to query the cost of
30878 * computing some function regions.
30879 */
30880
30881#include <map>
30882#include <string>
30883#include <vector>
30884
30885
30886namespace Halide {
30887namespace Internal {
30888
30889struct Cost {
30890 // Estimate of cycles spent doing arithmetic.
30891 Expr arith;
30892 // Estimate of bytes loaded.
30893 Expr memory;
30894
30895 Cost(int64_t arith, int64_t memory)
30896 : arith(arith), memory(memory) {
30897 }
30898 Cost(Expr arith, Expr memory)
30899 : arith(std::move(arith)), memory(std::move(memory)) {
30900 }
30901 Cost() = default;
30902
30903 inline bool defined() const {
30904 return arith.defined() && memory.defined();
30905 }
30906 void simplify();
30907
30908 friend std::ostream &operator<<(std::ostream &stream, const Cost &c) {
30909 stream << "[arith: " << c.arith << ", memory: " << c.memory << "]";
30910 return stream;
30911 }
30912};
30913
/** Auto scheduling component which is used to assign costs for computing a
 * region of a function or one of its stages. */
struct RegionCosts {
    /** An environment map which contains all functions in the pipeline. */
    std::map<std::string, Function> env;
    /** Realization order of functions in the pipeline. The first function to
     * be realized comes first. */
    std::vector<std::string> order;
    /** A map containing the cost of computing a value in each stage of a
     * function. The number of entries in the vector is equal to the number of
     * stages in the function. */
    std::map<std::string, std::vector<Cost>> func_cost;
    /** A map containing the types of all image inputs in the pipeline. */
    std::map<std::string, Type> inputs;
    /** A scope containing the estimated min/extent values of ImageParams
     * in the pipeline. */
    Scope<Interval> input_estimates;

    /** Return the cost of producing a region (specified by 'bounds') of a
     * function stage (specified by 'func' and 'stage'). 'inlines' specifies
     * names of all the inlined functions. */
    Cost stage_region_cost(const std::string &func, int stage, const DimBounds &bounds,
                           const std::set<std::string> &inlines = std::set<std::string>());

    /** Return the cost of producing a region of a function stage (specified
     * by 'func' and 'stage'). 'inlines' specifies names of all the inlined
     * functions. */
    Cost stage_region_cost(const std::string &func, int stage, const Box &region,
                           const std::set<std::string> &inlines = std::set<std::string>());

    /** Return the cost of producing a region of function 'func'. This adds up the
     * costs of all stages of 'func' required to produce the region. 'inlines'
     * specifies names of all the inlined functions. */
    Cost region_cost(const std::string &func, const Box &region,
                     const std::set<std::string> &inlines = std::set<std::string>());

    /** Same as region_cost above but this computes the total cost of many
     * function regions. */
    Cost region_cost(const std::map<std::string, Box> &regions,
                     const std::set<std::string> &inlines = std::set<std::string>());

    /** Compute the cost of producing a single value by one stage of 'f'.
     * 'inlines' specifies names of all the inlined functions. */
    Cost get_func_stage_cost(const Function &f, int stage,
                             const std::set<std::string> &inlines = std::set<std::string>()) const;

    /** Compute the cost of producing a single value by all stages of 'f'.
     * 'inlines' specifies names of all the inlined functions. This returns a
     * vector of costs. Each entry in the vector corresponds to a stage in 'f'. */
    std::vector<Cost> get_func_cost(const Function &f,
                                    const std::set<std::string> &inlines = std::set<std::string>());

    /** Computes the memory costs of computing a region (specified by 'bounds')
     * of a function stage (specified by 'func' and 'stage'). This returns a map
     * containing the costs incurred to access each of the functions required
     * to produce 'func'.
     * NOTE(review): unlike the overloads above, 'bounds' is taken by
     * non-const reference -- confirm whether it is mutated by the call. */
    std::map<std::string, Expr>
    stage_detailed_load_costs(const std::string &func, int stage, DimBounds &bounds,
                              const std::set<std::string> &inlines = std::set<std::string>());

    /** Return a map containing the costs incurred to access each of the functions
     * required to produce a single value of a function stage. */
    std::map<std::string, Expr>
    stage_detailed_load_costs(const std::string &func, int stage,
                              const std::set<std::string> &inlines = std::set<std::string>());

    /** Same as stage_detailed_load_costs above but this computes the cost of a region
     * of 'func'. */
    std::map<std::string, Expr>
    detailed_load_costs(const std::string &func, const Box &region,
                        const std::set<std::string> &inlines = std::set<std::string>());

    /** Same as detailed_load_costs above but this computes the cost of many function
     * regions and aggregates them. */
    std::map<std::string, Expr>
    detailed_load_costs(const std::map<std::string, Box> &regions,
                        const std::set<std::string> &inlines = std::set<std::string>());

    /** Return the size of the region of 'func' in bytes. */
    Expr region_size(const std::string &func, const Box &region);

    /** Return the size of the peak amount of memory allocated in bytes. This takes
     * the realization (topological) order of the function regions and the early
     * free mechanism into account while computing the peak footprint. */
    Expr region_footprint(const std::map<std::string, Box> &regions,
                          const std::set<std::string> &inlined = std::set<std::string>());

    /** Return the size of the input region in bytes. */
    Expr input_region_size(const std::string &input, const Box &region);

    /** Return the total size of the many input regions in bytes. */
    Expr input_region_size(const std::map<std::string, Box> &input_regions);

    /** Display the cost of each function in the pipeline. */
    void disp_func_costs();

    /** Construct a region cost object for the pipeline. 'env' is a map of all
     * functions in the pipeline. 'order' is the realization order of functions
     * in the pipeline. The first function to be realized comes first. */
    RegionCosts(const std::map<std::string, Function> &env,
                const std::vector<std::string> &order);
};
31016
/** Return true if the cost of inlining a function is equivalent to the
 * cost of calling the function directly. (Used by the auto scheduler.) */
bool is_func_trivial_to_inline(const Function &func);
31020
31021} // namespace Internal
31022} // namespace Halide
31023
31024#endif
31025#ifndef HALIDE_REMOVE_DEAD_ALLOCATIONS_H
31026#define HALIDE_REMOVE_DEAD_ALLOCATIONS_H
31027
31028/** \file
31029 * Defines the lowering pass that removes allocate and free nodes that
31030 * are not used.
31031 */
31032
31033
31034namespace Halide {
31035namespace Internal {
31036
/** Find Allocate/Free pairs that are never loaded from or stored to,
 * and remove them from the Stmt. This doesn't touch Realize/Call
 * nodes and so must be called after storage_flattening. Returns the
 * transformed statement.
 */
Stmt remove_dead_allocations(const Stmt &s);
31042
31043} // namespace Internal
31044} // namespace Halide
31045
31046#endif
31047#ifndef HALIDE_REMOVE_EXTERN_LOOPS
31048#define HALIDE_REMOVE_EXTERN_LOOPS
31049
31050
31051/** \file
31052 * Defines a lowering pass that removes placeholder loops for extern stages.
31053 */
31054
31055namespace Halide {
31056namespace Internal {
31057
/** Removes placeholder loops for extern stages, returning the
 * transformed statement. */
Stmt remove_extern_loops(const Stmt &s);
31060
31061} // namespace Internal
31062} // namespace Halide
31063
31064#endif
31065#ifndef HALIDE_REMOVE_UNDEF
31066#define HALIDE_REMOVE_UNDEF
31067
31068
/** \file
 * Defines a lowering pass that elides stores that depend on uninitialized values.
 */
31072
31073namespace Halide {
31074namespace Internal {
31075
/** Removes stores that depend on undef values, and statements that
 * only contain such stores. Returns the transformed statement. */
Stmt remove_undef(Stmt s);
31079
31080} // namespace Internal
31081} // namespace Halide
31082
31083#endif
31084#ifndef HALIDE_INTERNAL_SCHEDULE_FUNCTIONS_H
31085#define HALIDE_INTERNAL_SCHEDULE_FUNCTIONS_H
31086
31087/** \file
31088 *
31089 * Defines the function that does initial lowering of Halide Functions
31090 * into a loop nest using its schedule. The first stage of lowering.
31091 */
31092
31093#include <map>
31094#include <string>
31095#include <vector>
31096
31097
31098namespace Halide {
31099
31100struct Target;
31101
31102namespace Internal {
31103
31104class Function;
31105
/** Build loop nests and inject Function realizations at the
 * appropriate places using the schedule. Returns the resulting loop
 * nest; sets 'any_memoized' to indicate whether memoization passes
 * need to be run afterwards. */
Stmt schedule_functions(const std::vector<Function> &outputs,
                        const std::vector<std::vector<std::string>> &fused_groups,
                        const std::map<std::string, Function> &env,
                        const Target &target,
                        bool &any_memoized);
31114
31115} // namespace Internal
31116} // namespace Halide
31117
31118#endif
31119#ifndef HALIDE_INTERNAL_SELECT_GPU_API_H
31120#define HALIDE_INTERNAL_SELECT_GPU_API_H
31121
31122
31123/** \file
31124 * Defines a lowering pass that selects which GPU api to use for each
31125 * gpu for loop
31126 */
31127
31128namespace Halide {
31129
31130struct Target;
31131
31132namespace Internal {
31133
/** Replace for loops with GPU_Default device_api with an actual
 * device API depending on what's enabled in the target 't'. Chooses
 * the first of the following: opencl, cuda, openglcompute, opengl */
Stmt select_gpu_api(const Stmt &s, const Target &t);
31138
31139} // namespace Internal
31140} // namespace Halide
31141
31142#endif
31143#ifndef HALIDE_SIMPLIFY_H
31144#define HALIDE_SIMPLIFY_H
31145
31146/** \file
31147 * Methods for simplifying halide statements and expressions
31148 */
31149
31150
31151namespace Halide {
31152namespace Internal {
31153
/** Perform a wide range of simplifications to expressions and
 * statements, including constant folding, substituting in trivial
 * values, arithmetic rearranging, etc. Simplifies across let
 * statements, so must not be called on stmts with dangling or
 * repeated variable names.
 */
// @{
Stmt simplify(const Stmt &, bool remove_dead_code = true,
              const Scope<Interval> &bounds = Scope<Interval>::empty_scope(),
              const Scope<ModulusRemainder> &alignment = Scope<ModulusRemainder>::empty_scope());
Expr simplify(const Expr &, bool remove_dead_code = true,
              const Scope<Interval> &bounds = Scope<Interval>::empty_scope(),
              const Scope<ModulusRemainder> &alignment = Scope<ModulusRemainder>::empty_scope());
// @}

/** Attempt to statically prove an expression is true using the simplifier. */
bool can_prove(Expr e, const Scope<Interval> &bounds = Scope<Interval>::empty_scope());

/** Simplify expressions found in a statement, but don't simplify
 * across different statements. This is safe to perform at an earlier
 * stage in lowering than full simplification of a stmt. */
Stmt simplify_exprs(const Stmt &);
31176
31177} // namespace Internal
31178} // namespace Halide
31179
31180#endif
31181#ifndef HALIDE_SIMPLIFY_CORRELATED_DIFFERENCES
31182#define HALIDE_SIMPLIFY_CORRELATED_DIFFERENCES
31183
31184
31185/** \file
31186 * Defines a simplification pass for handling differences of correlated expressions.
31187 */
31188
31189namespace Halide {
31190namespace Internal {
31191
31192/** Symbolic interval arithmetic can be extremely conservative in
31193 * cases where we analyze the difference between two correlated
31194 * expressions. For example, consider:
31195 *
31196 * for x in [0, 10]:
31197 * let y = x + 3
31198 * let z = y - x
31199 *
31200 * x lies within [0, 10]. Interval arithmetic will correctly determine
31201 * that y lies within [3, 13]. When z is encountered, it is treated as
31202 * a difference of two independent variables, and gives [3 - 10, 13 -
31203 * 0] = [-7, 13] instead of the tighter interval [3, 3]. It
31204 * doesn't understand that y and x are correlated.
31205 *
 * In practice, this issue causes problems for unrolling, and
31207 * arbitrarily-bad overconservative behavior in bounds inference
31208 * (e.g. https://github.com/halide/Halide/issues/3697 )
31209 *
31210 * The function below attempts to address this by walking the IR,
31211 * remembering whether each let variable is monotonic increasing,
31212 * decreasing, unknown, or constant w.r.t each loop var. When it
31213 * encounters a subtract node where both sides have the same
31214 * monotonicity it substitutes, solves, and attempts to generally
31215 * simplify as aggressively as possible to try to cancel out the
31216 * repeated dependence on the loop var. The same is done for addition
31217 * nodes with arguments of opposite monotonicity.
31218 *
31219 * Bounds inference is particularly sensitive to these false
31220 * dependencies, but removing false dependencies also helps other
31221 * lowering passes. E.g. if this simplification means a value no
31222 * longer depends on a loop variable, it can remain scalar during
31223 * vectorization of that loop, or we can lift it out as a loop
31224 * invariant, or it might avoid some of the complex paths in GPU
31225 * codegen that trigger when values depend on the block index
31226 * (e.g. warp shuffles).
31227 *
31228 * This pass is safe to use on code with repeated instances of the
31229 * same variable name (it must be, because we want to run it before
31230 * allocation bounds inference).
31231 */
31232Stmt simplify_correlated_differences(const Stmt &);
31233
31234} // namespace Internal
31235} // namespace Halide
31236
31237#endif
31238#ifndef SIMPLIFY_SPECIALIZATIONS_H
31239#define SIMPLIFY_SPECIALIZATIONS_H
31240
31241/** \file
31242 *
31243 * Defines pass that try to simplify the RHS/LHS of a function's definition
31244 * based on its specializations.
31245 */
31246
31247#include <map>
31248#include <string>
31249
31250
31251namespace Halide {
31252namespace Internal {
31253
31254class Function;
31255
31256/** Try to simplify the RHS/LHS of a function's definition based on its
31257 * specializations. */
31258void simplify_specializations(std::map<std::string, Function> &env);
31259
31260} // namespace Internal
31261} // namespace Halide
31262
31263#endif
31264#ifndef HALIDE_SKIP_STAGES
31265#define HALIDE_SKIP_STAGES
31266
31267#include <string>
31268#include <vector>
31269
31270
31271/** \file
31272 * Defines a pass that dynamically avoids realizing unnecessary stages.
31273 */
31274
31275namespace Halide {
31276namespace Internal {
31277
31278/** Avoid computing certain stages if we can infer a runtime condition
31279 * to check that tells us they won't be used. Does this by analyzing
31280 * all reads of each buffer allocated, and inferring some condition
31281 * that tells us if the reads occur. If the condition is non-trivial,
31282 * inject ifs that guard the production. */
31283Stmt skip_stages(Stmt s, const std::vector<std::string> &order);
31284
31285} // namespace Internal
31286} // namespace Halide
31287
31288#endif
31289#ifndef HALIDE_SLIDING_WINDOW_H
31290#define HALIDE_SLIDING_WINDOW_H
31291
31292/** \file
31293 *
31294 * Defines the sliding_window lowering optimization pass, which avoids
31295 * computing provably-already-computed values.
31296 */
31297
31298#include <map>
31299#include <string>
31300
31301
31302namespace Halide {
31303namespace Internal {
31304
31305class Function;
31306
31307/** Perform sliding window optimizations on a halide
31308 * statement. I.e. don't bother computing points in a function that
31309 * have provably already been computed by a previous iteration.
31310 */
31311Stmt sliding_window(const Stmt &s, const std::map<std::string, Function> &env);
31312
31313} // namespace Internal
31314} // namespace Halide
31315
31316#endif
31317#ifndef SOLVE_H
31318#define SOLVE_H
31319
31320/** Defines methods for manipulating and analyzing boolean expressions. */
31321
31322
31323namespace Halide {
31324namespace Internal {
31325
struct SolverResult {
    // The mutated expression produced by solve_expression().
    Expr result;
    // True if the variable being solved for occurs exactly once in
    // `result`; false if the expression was only partially solved and
    // multiple instances of the variable remain.
    bool fully_solved;
};
31330
31331/** Attempts to collect all instances of a variable in an expression
31332 * tree and place it as far to the left as possible, and as far up the
31333 * tree as possible (i.e. outside most parentheses). If the expression
31334 * is an equality or comparison, this 'solves' the equation. Returns a
31335 * pair of Expr and bool. The Expr is the mutated expression, and the
31336 * bool indicates whether there is a single instance of the variable
31337 * in the result. If it is false, the expression has only been partially
31338 * solved, and there are still multiple instances of the variable. */
31339SolverResult solve_expression(
31340 const Expr &e, const std::string &variable,
31341 const Scope<Expr> &scope = Scope<Expr>::empty_scope());
31342
31343/** Find the smallest interval such that the condition is either true
31344 * or false inside of it, but definitely false outside of it. Never
31345 * returns undefined Exprs, instead it uses variables called "pos_inf"
31346 * and "neg_inf" to represent positive and negative infinity. */
31347Interval solve_for_outer_interval(const Expr &c, const std::string &variable);
31348
31349/** Find the largest interval such that the condition is definitely
31350 * true inside of it, and might be true or false outside of it. */
31351Interval solve_for_inner_interval(const Expr &c, const std::string &variable);
31352
31353/** Take a conditional that includes variables that vary over some
31354 * domain, and convert it to a more conservative (less frequently
31355 * true) condition that doesn't depend on those variables. Formally,
31356 * the output expr implies the input expr.
31357 *
31358 * The condition may be a vector condition, in which case we also
31359 * 'and' over the vector lanes, and return a scalar result. */
31360Expr and_condition_over_domain(const Expr &c, const Scope<Interval> &varying);
31361
31362void solve_test();
31363
31364} // namespace Internal
31365} // namespace Halide
31366
31367#endif
31368#ifndef HALIDE_SPLIT_TUPLES_H
31369#define HALIDE_SPLIT_TUPLES_H
31370
31371#include <map>
31372
31373/** \file
31374 * Defines the lowering pass that breaks up Tuple-valued realization
31375 * and productions into several scalar-valued ones. */
31376
31377namespace Halide {
31378namespace Internal {
31379
31380class Function;
31381
31382/** Rewrite all tuple-valued Realizations, Provide nodes, and Call
31383 * nodes into several scalar-valued ones, so that later lowering
31384 * passes only need to think about scalar-valued productions. */
31385
31386Stmt split_tuples(const Stmt &s, const std::map<std::string, Function> &env);
31387
31388} // namespace Internal
31389} // namespace Halide
31390
31391#endif
31392#ifndef HALIDE_STMT_TO_HTML
31393#define HALIDE_STMT_TO_HTML
31394
31395/** \file
31396 * Defines a function to dump an HTML-formatted stmt to a file.
31397 */
31398
31399#include <string>
31400
31401namespace Halide {
31402
31403class Module;
31404
31405namespace Internal {
31406
31407struct Stmt;
31408
31409/**
31410 * Dump an HTML-formatted print of a Stmt to filename.
31411 */
31412void print_to_html(const std::string &filename, const Stmt &s);
31413
31414/** Dump an HTML-formatted print of a Module to filename. */
31415void print_to_html(const std::string &filename, const Module &m);
31416
31417} // namespace Internal
31418} // namespace Halide
31419
31420#endif
31421#ifndef HALIDE_STORAGE_FLATTENING_H
31422#define HALIDE_STORAGE_FLATTENING_H
31423
31424/** \file
31425 * Defines the lowering pass that flattens multi-dimensional storage
31426 * into single-dimensional array access
31427 */
31428
31429#include <map>
31430#include <string>
31431#include <vector>
31432
31433
31434namespace Halide {
31435
31436struct Target;
31437
31438namespace Internal {
31439
31440class Function;
31441
31442/** Take a statement with multi-dimensional Realize, Provide, and Call
31443 * nodes, and turn it into a statement with single-dimensional
31444 * Allocate, Store, and Load nodes respectively. */
31445Stmt storage_flattening(Stmt s,
31446 const std::vector<Function> &outputs,
31447 const std::map<std::string, Function> &env,
31448 const Target &target);
31449
31450} // namespace Internal
31451} // namespace Halide
31452
31453#endif
31454#ifndef HALIDE_STORAGE_FOLDING_H
31455#define HALIDE_STORAGE_FOLDING_H
31456
31457/** \file
31458 * Defines the lowering optimization pass that reduces large buffers
31459 * down to smaller circular buffers when possible
31460 */
31461#include <map>
31462#include <string>
31463
31464
31465namespace Halide {
31466namespace Internal {
31467
31468class Function;
31469
31470/** Fold storage of functions if possible. This means reducing one of
 * the dimensions modulo something for the purpose of storage, if we
31472 * can prove that this is safe to do. E.g consider:
31473 *
31474 \code
31475 f(x) = ...
31476 g(x) = f(x-1) + f(x)
31477 f.store_root().compute_at(g, x);
31478 \endcode
31479 *
31480 * We can store f as a circular buffer of size two, instead of
31481 * allocating space for all of it.
31482 */
31483Stmt storage_folding(const Stmt &s, const std::map<std::string, Function> &env);
31484
31485} // namespace Internal
31486} // namespace Halide
31487
31488#endif
31489#ifndef HALIDE_STRICTIFY_FLOAT_H
31490#define HALIDE_STRICTIFY_FLOAT_H
31491
31492/** \file
31493 * Defines a lowering pass to make all floating-point strict for all top-level Exprs.
31494 */
31495
31496#include <map>
31497#include <string>
31498
31499namespace Halide {
31500
31501struct Target;
31502
31503namespace Internal {
31504
31505class Function;
31506
/** Propagate strict_float intrinsics such that they immediately wrap
31508 * all floating-point expressions. This makes the IR nodes context
31509 * independent. If the Target::StrictFloat flag is specified in
31510 * target, starts in strict_float mode so all floating-point type
31511 * Exprs in the compilation will be marked with strict_float. Returns
31512 * whether any strict floating-point is used in any function in the
31513 * passed in env.
31514 */
31515bool strictify_float(std::map<std::string, Function> &env, const Target &t);
31516
31517} // namespace Internal
31518} // namespace Halide
31519
31520#endif
31521#ifndef HALIDE_SUBSTITUTE_H
31522#define HALIDE_SUBSTITUTE_H
31523
31524/** \file
31525 *
31526 * Defines methods for substituting out variables in expressions and
31527 * statements. */
31528
31529#include <map>
31530
31531
31532namespace Halide {
31533namespace Internal {
31534
31535/** Substitute variables with the given name with the replacement
31536 * expression within expr. This is a dangerous thing to do if variable
31537 * names have not been uniquified. While it won't traverse inside let
31538 * statements with the same name as the first argument, moving a piece
31539 * of syntax around can change its meaning, because it can cross lets
31540 * that redefine variable names that it includes references to. */
31541Expr substitute(const std::string &name, const Expr &replacement, const Expr &expr);
31542
31543/** Substitute variables with the given name with the replacement
31544 * expression within stmt. */
31545Stmt substitute(const std::string &name, const Expr &replacement, const Stmt &stmt);
31546
31547/** Substitute variables with names in the map. */
31548// @{
31549Expr substitute(const std::map<std::string, Expr> &replacements, const Expr &expr);
31550Stmt substitute(const std::map<std::string, Expr> &replacements, const Stmt &stmt);
31551// @}
31552
31553/** Substitute expressions for other expressions. */
31554// @{
31555Expr substitute(const Expr &find, const Expr &replacement, const Expr &expr);
31556Stmt substitute(const Expr &find, const Expr &replacement, const Stmt &stmt);
31557// @}
31558
31559/** Substitutions where the IR may be a general graph (and not just a
31560 * DAG). */
31561// @{
31562Expr graph_substitute(const std::string &name, const Expr &replacement, const Expr &expr);
31563Stmt graph_substitute(const std::string &name, const Expr &replacement, const Stmt &stmt);
31564Expr graph_substitute(const Expr &find, const Expr &replacement, const Expr &expr);
31565Stmt graph_substitute(const Expr &find, const Expr &replacement, const Stmt &stmt);
31566// @}
31567
31568/** Substitute in all let Exprs in a piece of IR. Doesn't substitute
31569 * in let stmts, as this may change the meaning of the IR (e.g. by
31570 * moving a load after a store). Produces graphs of IR, so don't use
31571 * non-graph-aware visitors or mutators on it until you've CSE'd the
31572 * result. */
31573// @{
31574Expr substitute_in_all_lets(const Expr &expr);
31575Stmt substitute_in_all_lets(const Stmt &stmt);
31576// @}
31577
31578} // namespace Internal
31579} // namespace Halide
31580
31581#endif
31582#ifndef HALIDE_THREAD_POOL_H
31583#define HALIDE_THREAD_POOL_H
31584
31585#include <condition_variable>
31586#include <future>
31587#include <mutex>
31588#include <queue>
31589#include <thread>
31590#include <utility>
31591
31592#ifdef _MSC_VER
31593#else
31594#include <unistd.h>
31595#endif
31596
31597/** \file
31598 * Define a simple thread pool utility that is modeled on the api of
31599 * C++11 std::async(); since implementation details of std::async
31600 * can vary considerably, with no control over thread spawning, this class
31601 * allows us to use the same model but with precise control over thread usage.
31602 *
31603 * A ThreadPool is created with a specific number of threads, which will never
31604 * vary over the life of the ThreadPool. (If created without a specific number
31605 * of threads, it will attempt to use threads == number-of-cores.)
31606 *
31607 * Each async request will go into a queue, and will be serviced by the next
31608 * available thread from the pool.
31609 *
31610 * The ThreadPool's dtor will block until all currently-executing tasks
 * finish (but won't schedule any more).
31612 *
31613 * Note that this is a fairly simpleminded ThreadPool, meant for tasks
31614 * that are fairly coarse (e.g. different tasks in a test); it is specifically
31615 * *not* intended to be the underlying implementation for Halide runtime threads
31616 */
31617namespace Halide {
31618namespace Internal {
31619
// A simple fixed-size thread pool modeled on std::async: tasks are
// enqueued via async() and serviced FIFO by a fixed set of worker
// threads. See the file comment above for the design rationale.
template<typename T>
class ThreadPool {
    // A queued unit of work: the callable to run, plus the promise
    // through which its result is delivered to the std::future
    // returned by async().
    struct Job {
        std::function<T()> func;
        std::promise<T> result;

        // Runs func with the pool mutex (held via unique_lock)
        // released for the duration of the call, re-acquiring it
        // before fulfilling the promise. Defined out-of-class below,
        // with a specialization for T = void.
        void run_unlocked(std::unique_lock<std::mutex> &unique_lock);
    };

    // all fields are protected by this mutex.
    std::mutex mutex;

    // Queue of Jobs.
    std::queue<Job> jobs;

    // Broadcast whenever items are added to the Job queue.
    std::condition_variable wakeup_threads;

    // Keep track of threads so they can be joined at shutdown
    std::vector<std::thread> threads;

    // True if the pool is shutting down.
    bool shutting_down{false};

    // Body of each worker thread: loop under the mutex, sleeping on
    // the condition variable while the queue is empty, otherwise
    // popping and running the next job. Exits as soon as
    // shutting_down is observed; jobs still queued at that point are
    // not run.
    void worker_thread() {
        std::unique_lock<std::mutex> unique_lock(mutex);
        while (!shutting_down) {
            if (jobs.empty()) {
                // There are no jobs pending. Wait until more jobs are enqueued.
                // (Spurious wakeups are harmless: the loop re-checks the queue.)
                wakeup_threads.wait(unique_lock);
            } else {
                // Grab the next job.
                Job cur_job = std::move(jobs.front());
                jobs.pop();
                // Runs with the lock released; the lock is re-held
                // when this returns.
                cur_job.run_unlocked(unique_lock);
            }
        }
    }

public:
    // Best-effort count of available hardware threads; used as the
    // default pool size.
    static size_t num_processors_online() {
#ifdef _WIN32
        // Windows exposes the core count via this environment
        // variable; fall back to 8 if it is unset or unparsable.
        char *num_cores = getenv("NUMBER_OF_PROCESSORS");
        return num_cores ? atoi(num_cores) : 8;
#else
        return sysconf(_SC_NPROCESSORS_ONLN);
#endif
    }

    // Default to number of available cores if not specified otherwise
    ThreadPool(size_t desired_num_threads = num_processors_online()) {
        // This file doesn't depend on anything else in libHalide, so
        // we'll use assert, not internal_assert.
        assert(desired_num_threads > 0);

        std::lock_guard<std::mutex> lock(mutex);

        // Create all the threads.
        for (size_t i = 0; i < desired_num_threads; ++i) {
            threads.emplace_back([this] { worker_thread(); });
        }
    }

    // Blocks until all currently-executing jobs finish; jobs still
    // waiting in the queue are not run.
    ~ThreadPool() {
        // Wake everyone up and tell them the party's over and it's time to go home
        {
            std::lock_guard<std::mutex> lock(mutex);
            shutting_down = true;
            wakeup_threads.notify_all();
        }

        // Wait until they leave
        for (auto &t : threads) {
            t.join();
        }
    }

    // Enqueue a call to func(args...) and return a future for its
    // result. args are copied (not forwarded), so the job never holds
    // references into the caller's frame.
    template<typename Func, typename... Args>
    std::future<T> async(Func func, Args... args) {
        std::lock_guard<std::mutex> lock(mutex);

        Job job;
        // Don't use std::forward here: we never want args passed by reference,
        // since they will be accessed from an arbitrary thread.
        //
        // Some versions of GCC won't allow capturing variadic arguments in a lambda;
        //
        // job.func = [func, args...]() -> T { return func(args...); }; // Nope, sorry
        //
        // fortunately, we can use std::bind() to accomplish the same thing.
        job.func = std::bind(func, args...);
        jobs.emplace(std::move(job));
        // Safe to touch jobs.back() here: we still hold the mutex, so
        // no worker can pop the job before we take its future.
        std::future<T> result = jobs.back().result.get_future();

        // Wake up our threads.
        wakeup_threads.notify_all();

        return result;
    }
};
31720
31721template<typename T>
31722inline void ThreadPool<T>::Job::run_unlocked(std::unique_lock<std::mutex> &unique_lock) {
31723 unique_lock.unlock();
31724 T r = func();
31725 unique_lock.lock();
31726 result.set_value(std::move(r));
31727}
31728
31729template<>
31730inline void ThreadPool<void>::Job::run_unlocked(std::unique_lock<std::mutex> &unique_lock) {
31731 unique_lock.unlock();
31732 func();
31733 unique_lock.lock();
31734 result.set_value();
31735}
31736
31737} // namespace Internal
31738} // namespace Halide
31739
31740#endif // HALIDE_THREAD_POOL_H
31741#ifndef HALIDE_TRACING_H
31742#define HALIDE_TRACING_H
31743
31744/** \file
31745 * Defines the lowering pass that injects print statements when tracing is turned on
31746 */
31747
31748#include <map>
31749#include <string>
31750#include <vector>
31751
31752
31753namespace Halide {
31754
31755struct Target;
31756
31757namespace Internal {
31758
31759class Function;
31760
31761/** Take a statement representing a halide pipeline, inject calls to
31762 * tracing functions at interesting points, such as
31763 * allocations. Should be done before storage flattening, but after
31764 * all bounds inference. */
31765Stmt inject_tracing(Stmt, const std::string &pipeline_name,
31766 bool trace_pipeline,
31767 const std::map<std::string, Function> &env,
31768 const std::vector<Function> &outputs,
31769 const Target &Target);
31770
31771} // namespace Internal
31772} // namespace Halide
31773
31774#endif
31775#ifndef TRIM_NO_OPS_H
31776#define TRIM_NO_OPS_H
31777
31778/** \file
31779 * Defines a lowering pass that truncates loops to the region over
31780 * which they actually do something.
31781 */
31782
31783
31784namespace Halide {
31785namespace Internal {
31786
31787/** Truncate loop bounds to the region over which they actually do
31788 * something. For examples see test/correctness/trim_no_ops.cpp */
31789Stmt trim_no_ops(Stmt s);
31790
31791} // namespace Internal
31792} // namespace Halide
31793
31794#endif
31795#ifndef HALIDE_UNIFY_DUPLICATE_LETS_H
31796#define HALIDE_UNIFY_DUPLICATE_LETS_H
31797
31798/** \file
31799 * Defines the lowering pass that coalesces redundant let statements
31800 */
31801
31802
31803namespace Halide {
31804namespace Internal {
31805
31806/** Find let statements that all define the same value, and make later
31807 * ones just reuse the symbol names of the earlier ones. */
31808Stmt unify_duplicate_lets(const Stmt &s);
31809
31810} // namespace Internal
31811} // namespace Halide
31812
31813#endif
31814#ifndef HALIDE_UNIQUIFY_VARIABLE_NAMES
31815#define HALIDE_UNIQUIFY_VARIABLE_NAMES
31816
31817/** \file
31818 * Defines the lowering pass that renames all variables to have unique names.
31819 */
31820
31821
31822namespace Halide {
31823namespace Internal {
31824
31825/** Modify a statement so that every internally-defined variable name
31826 * is unique. This lets later passes assume syntactic equivalence is
31827 * semantic equivalence. */
31828Stmt uniquify_variable_names(const Stmt &s);
31829
31830void uniquify_variable_names_test();
31831
31832} // namespace Internal
31833} // namespace Halide
31834
31835#endif
31836#ifndef HALIDE_UNPACK_BUFFERS_H
31837#define HALIDE_UNPACK_BUFFERS_H
31838
31839/** \file
31840 * Defines the lowering pass that unpacks buffer arguments onto the symbol table
31841 */
31842
31843
31844namespace Halide {
31845namespace Internal {
31846
31847/** Creates let stmts for the various buffer components
31848 * (e.g. foo.extent.0) in any referenced concrete buffers or buffer
 * parameters. After this pass, the only undefined symbols should be
 * scalar parameters and the buffers themselves (e.g. foo.buffer). */
31851Stmt unpack_buffers(Stmt s);
31852
31853} // namespace Internal
31854} // namespace Halide
31855
31856#endif
31857#ifndef HALIDE_UNROLL_LOOPS_H
31858#define HALIDE_UNROLL_LOOPS_H
31859
31860/** \file
31861 * Defines the lowering pass that unrolls loops marked as such
31862 */
31863
31864
31865namespace Halide {
31866namespace Internal {
31867
31868/** Take a statement with for loops marked for unrolling, and convert
31869 * each into several copies of the innermost statement. I.e. unroll
31870 * the loop. */
31871Stmt unroll_loops(const Stmt &);
31872
31873} // namespace Internal
31874} // namespace Halide
31875
31876#endif
31877#ifndef HALIDE_UNSAFE_PROMISES_H
31878#define HALIDE_UNSAFE_PROMISES_H
31879
31880/** \file
31881 * Defines the lowering pass that removes unsafe promises
31882 */
31883
31884
31885namespace Halide {
31886
31887struct Target;
31888
31889namespace Internal {
31890
31891/** Lower all unsafe promises into either assertions or unchecked
 * code, depending on the target. */
31893Stmt lower_unsafe_promises(const Stmt &s, const Target &t);
31894
31895/** Lower all safe promises by just stripping them. This is a good
31896 * idea once no more lowering stages are going to use
31897 * boxes_touched. */
31898Stmt lower_safe_promises(const Stmt &s);
31899
31900} // namespace Internal
31901} // namespace Halide
31902
31903#endif
31904#ifndef HALIDE_VECTORIZE_LOOPS_H
31905#define HALIDE_VECTORIZE_LOOPS_H
31906
31907/** \file
31908 * Defines the lowering pass that vectorizes loops marked as such
31909 */
31910
31911
31912#include <map>
31913
31914namespace Halide {
31915
31916struct Target;
31917
31918namespace Internal {
31919
31920/** Take a statement with for loops marked for vectorization, and turn
31921 * them into single statements that operate on vectors. The loops in
31922 * question must have constant extent.
31923 */
31924Stmt vectorize_loops(const Stmt &s, const std::map<std::string, Function> &env, const Target &t);
31925
31926} // namespace Internal
31927} // namespace Halide
31928
31929#endif
31930#ifndef HALIDE_WASM_EXECUTOR_H
31931#define HALIDE_WASM_EXECUTOR_H
31932
31933/** \file
31934 *
31935 * Support for running Halide-compiled Wasm code in-process.
31936 * Bindings for parameters, extern calls, etc. are established and the
31937 * Wasm code is executed. Allows calls to realize to work
31938 * exactly as if native code had been run, but via a JavaScript/Wasm VM.
31939 * Currently, only the WABT interpreter is supported.
31940 */
31941
31942
31943namespace Halide {
31944
31945struct Target;
31946
31947namespace Internal {
31948
31949struct WasmModuleContents;
31950
31951/** Handle to compiled wasm code which can be called later. */
struct WasmModule {
    // Handle to the opaque, internally-managed compiled-module state.
    Internal::IntrusivePtr<WasmModuleContents> contents;

    /** If the given target can be executed via the wasm executor, return true. */
    static bool can_jit_target(const Target &target);

    /** Compile generated wasm code with a set of externs. */
    static WasmModule compile(
        const Module &module,
        const std::vector<Argument> &arguments,
        const std::string &fn_name,
        const std::map<std::string, JITExtern> &externs,
        const std::vector<JITModule> &extern_deps);

    /** Run previously-compiled wasm code with a set of arguments. */
    int run(const void **args);
};
31969
31970} // namespace Internal
31971} // namespace Halide
31972
31973#endif // HALIDE_WASM_EXECUTOR_H
31974#ifndef HALIDE_WRAP_CALLS_H
31975#define HALIDE_WRAP_CALLS_H
31976
31977/** \file
31978 *
31979 * Defines pass to replace calls to wrapped Functions with their wrappers.
31980 */
31981
31982#include <map>
31983#include <string>
31984
31985namespace Halide {
31986namespace Internal {
31987
31988class Function;
31989
31990/** Replace every call to wrapped Functions in the Functions' definitions with
31991 * call to their wrapper functions. */
31992std::map<std::string, Function> wrap_func_calls(const std::map<std::string, Function> &env);
31993
31994} // namespace Internal
31995} // namespace Halide
31996
31997#endif
31998
31999// Clean up macros used inside Halide headers
32000#undef user_assert
32001#undef user_error
32002#undef user_warning
32003#undef internal_error
32004#undef internal_assert
32005#undef halide_runtime_error
32006#endif // HALIDE_H
32007